bioimageio.spec.model.v0_5
1from __future__ import annotations 2 3import collections.abc 4import re 5import string 6import warnings 7from abc import ABC 8from copy import deepcopy 9from itertools import chain 10from math import ceil 11from pathlib import Path, PurePosixPath 12from tempfile import mkdtemp 13from typing import ( 14 TYPE_CHECKING, 15 Any, 16 Callable, 17 ClassVar, 18 Dict, 19 Generic, 20 List, 21 Literal, 22 Mapping, 23 NamedTuple, 24 Optional, 25 Sequence, 26 Set, 27 Tuple, 28 Type, 29 TypeVar, 30 Union, 31 cast, 32) 33 34import numpy as np 35from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate 36from imageio.v3 import imread, imwrite # pyright: ignore[reportUnknownVariableType] 37from loguru import logger 38from numpy.typing import NDArray 39from pydantic import ( 40 AfterValidator, 41 Discriminator, 42 Field, 43 RootModel, 44 SerializationInfo, 45 SerializerFunctionWrapHandler, 46 StrictInt, 47 Tag, 48 ValidationInfo, 49 WrapSerializer, 50 field_validator, 51 model_serializer, 52 model_validator, 53) 54from typing_extensions import Annotated, Self, assert_never, get_args 55 56from .._internal.common_nodes import ( 57 InvalidDescr, 58 Node, 59 NodeWithExplicitlySetFields, 60) 61from .._internal.constants import DTYPE_LIMITS 62from .._internal.field_warning import issue_warning, warn 63from .._internal.io import BioimageioYamlContent as BioimageioYamlContent 64from .._internal.io import FileDescr as FileDescr 65from .._internal.io import ( 66 FileSource, 67 WithSuffix, 68 YamlValue, 69 extract_file_name, 70 get_reader, 71 wo_special_file_name, 72) 73from .._internal.io_basics import Sha256 as Sha256 74from .._internal.io_packaging import ( 75 FileDescr_, 76 FileSource_, 77 package_file_descr_serializer, 78) 79from .._internal.io_utils import load_array 80from .._internal.node_converter import Converter 81from .._internal.type_guards import is_dict, is_sequence 82from .._internal.types import ( 83 FAIR, 84 AbsoluteTolerance, 85 LowerCaseIdentifier, 86 LowerCaseIdentifierAnno, 87 MismatchedElementsPerMillion, 88 RelativeTolerance, 89) 90from .._internal.types import Datetime as Datetime 91from .._internal.types import Identifier as Identifier 92from .._internal.types import NotEmpty as NotEmpty 93from .._internal.types import SiUnit as SiUnit 94from .._internal.url import HttpUrl as HttpUrl 95from .._internal.validation_context import get_validation_context 96from .._internal.validator_annotations import RestrictCharacters 97from .._internal.version_type import Version as Version 98from .._internal.warning_levels import INFO 99from ..dataset.v0_2 import DatasetDescr as DatasetDescr02 100from ..dataset.v0_2 import LinkedDataset as LinkedDataset02 101from ..dataset.v0_3 import DatasetDescr as DatasetDescr 102from ..dataset.v0_3 import DatasetId as DatasetId 103from ..dataset.v0_3 import LinkedDataset as LinkedDataset 104from ..dataset.v0_3 import Uploader as Uploader 105from ..generic.v0_3 import ( 106 VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS, 107) 108from ..generic.v0_3 import Author as Author 109from ..generic.v0_3 import BadgeDescr as BadgeDescr 110from ..generic.v0_3 import CiteEntry as CiteEntry 111from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId 112from ..generic.v0_3 import Doi as Doi 113from ..generic.v0_3 import ( 114 FileSource_documentation, 115 GenericModelDescrBase, 116 LinkedResourceBase, 117 _author_conv, # pyright: ignore[reportPrivateUsage] 118 _maintainer_conv, # pyright: ignore[reportPrivateUsage] 119) 120from ..generic.v0_3 
import LicenseId as LicenseId 121from ..generic.v0_3 import LinkedResource as LinkedResource 122from ..generic.v0_3 import Maintainer as Maintainer 123from ..generic.v0_3 import OrcidId as OrcidId 124from ..generic.v0_3 import RelativeFilePath as RelativeFilePath 125from ..generic.v0_3 import ResourceId as ResourceId 126from .v0_4 import Author as _Author_v0_4 127from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4 128from .v0_4 import CallableFromDepencency as CallableFromDepencency 129from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4 130from .v0_4 import CallableFromFile as _CallableFromFile_v0_4 131from .v0_4 import ClipDescr as _ClipDescr_v0_4 132from .v0_4 import ClipKwargs as ClipKwargs 133from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4 134from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4 135from .v0_4 import KnownRunMode as KnownRunMode 136from .v0_4 import ModelDescr as _ModelDescr_v0_4 137from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4 138from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4 139from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4 140from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4 141from .v0_4 import ProcessingKwargs as ProcessingKwargs 142from .v0_4 import RunMode as RunMode 143from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4 144from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4 145from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4 146from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4 147from .v0_4 import TensorName as _TensorName_v0_4 148from .v0_4 import WeightsFormat as WeightsFormat 149from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4 150from .v0_4 import package_weights 151 152SpaceUnit = Literal[ 153 "attometer", 154 "angstrom", 155 "centimeter", 156 "decimeter", 157 "exameter", 158 "femtometer", 159 "foot", 160 "gigameter", 161 "hectometer", 162 "inch", 163 "kilometer", 164 "megameter", 165 "meter", 166 "micrometer", 167 "mile", 168 "millimeter", 169 "nanometer", 170 "parsec", 171 "petameter", 172 "picometer", 173 "terameter", 174 "yard", 175 "yoctometer", 176 "yottameter", 177 "zeptometer", 178 "zettameter", 179] 180"""Space unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)""" 181 182TimeUnit = Literal[ 183 "attosecond", 184 "centisecond", 185 "day", 186 "decisecond", 187 "exasecond", 188 "femtosecond", 189 "gigasecond", 190 "hectosecond", 191 "hour", 192 "kilosecond", 193 "megasecond", 194 "microsecond", 195 "millisecond", 196 "minute", 197 "nanosecond", 198 "petasecond", 199 "picosecond", 200 "second", 201 "terasecond", 202 "yoctosecond", 203 "yottasecond", 204 "zeptosecond", 205 "zettasecond", 206] 207"""Time unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)""" 208 209AxisType = Literal["batch", "channel", "index", "time", "space"] 210 211_AXIS_TYPE_MAP: Mapping[str, AxisType] = { 212 "b": "batch", 213 "t": "time", 214 "i": "index", 215 "c": "channel", 216 "x": "space", 217 "y": "space", 218 "z": "space", 219} 220 221_AXIS_ID_MAP = { 222 "b": "batch", 223 "t": "time", 224 "i": "index", 225 "c": "channel", 226} 227 228 229class TensorId(LowerCaseIdentifier): 230 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 231 Annotated[LowerCaseIdentifierAnno, MaxLen(32)] 232 ] 233 234 235def _normalize_axis_id(a: str): 236 a = str(a) 237 normalized = 
_AXIS_ID_MAP.get(a, a)
    if a != normalized:
        logger.opt(depth=3).warning(
            "Normalized axis id from '{}' to '{}'.", a, normalized
        )
    return normalized


class AxisId(LowerCaseIdentifier):
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            LowerCaseIdentifierAnno,
            MaxLen(16),
            AfterValidator(_normalize_axis_id),
        ]
    ]


def _is_batch(a: str) -> bool:
    return str(a) == "batch"


def _is_not_batch(a: str) -> bool:
    return not _is_batch(a)


NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)]

PreprocessingId = Literal[
    "binarize",
    "clip",
    "ensure_dtype",
    "fixed_zero_mean_unit_variance",
    "scale_linear",
    "scale_range",
    "sigmoid",
    "softmax",
]
PostprocessingId = Literal[
    "binarize",
    "clip",
    "ensure_dtype",
    "fixed_zero_mean_unit_variance",
    "scale_linear",
    "scale_mean_variance",
    "scale_range",
    "sigmoid",
    "softmax",
    "zero_mean_unit_variance",
]


SAME_AS_TYPE = "<same as type>"


ParameterizedSize_N = int
"""
Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`.
"""


class ParameterizedSize(Node):
    """Describes a range of valid tensor axis sizes as `size = min + n*step`.

    - **min** and **step** are given by the model description.
    - All blocksize parameters n = 0,1,2,... yield a valid `size`.
    - A greater blocksize parameter n = 0,1,2,... results in a greater **size**.
      This allows adjusting the axis size generically.
    """

    N: ClassVar[Type[int]] = ParameterizedSize_N
    """Positive integer to parameterize this axis"""

    min: Annotated[int, Gt(0)]
    step: Annotated[int, Gt(0)]

    def validate_size(self, size: int) -> int:
        if size < self.min:
            raise ValueError(f"size {size} < {self.min}")
        if (size - self.min) % self.step != 0:
            raise ValueError(
                f"axis of size {size} is not parameterized by `min + n*step` ="
                + f" `{self.min} + n*{self.step}`"
            )

        return size

    def get_size(self, n: ParameterizedSize_N) -> int:
        return self.min + self.step * n

    def get_n(self, s: int) -> ParameterizedSize_N:
        """return smallest n parameterizing a size greater than or equal to `s`"""
        return ceil((s - self.min) / self.step)


class DataDependentSize(Node):
    min: Annotated[int, Gt(0)] = 1
    max: Annotated[Optional[int], Gt(1)] = None

    @model_validator(mode="after")
    def _validate_max_gt_min(self):
        if self.max is not None and self.min >= self.max:
            raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")

        return self

    def validate_size(self, size: int) -> int:
        if size < self.min:
            raise ValueError(f"size {size} < {self.min}")

        if self.max is not None and size > self.max:
            raise ValueError(f"size {size} > {self.max}")

        return size


class SizeReference(Node):
    """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.

    `axis.size = reference.size * reference.scale / axis.scale + offset`

    Note:
    1. The axis and the referenced axis need to have the same unit (or no unit).
    2. Batch axes may not be referenced.
    3. Fractions are rounded down.
    4. If the reference axis is `concatenable` the referencing axis is assumed to be
       `concatenable` as well with the same block order.

    Example:
    An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196mm².
    Let's assume that we want to express the image height h in relation to its width w
    instead of only accepting input images of exactly 100*49 pixels
    (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).

    >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
    >>> h = SpaceInputAxis(
    ...     id=AxisId("h"),
    ...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
    ...     unit="millimeter",
    ...     scale=4,
    ... )
    >>> print(h.size.get_size(h, w))
    49

    ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
    """

    tensor_id: TensorId
    """tensor id of the reference axis"""

    axis_id: AxisId
    """axis id of the reference axis"""

    offset: StrictInt = 0

    def get_size(
        self,
        axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
        ref_axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
        n: ParameterizedSize_N = 0,
        ref_size: Optional[int] = None,
    ):
        """Compute the concrete size for a given axis and its reference axis.

        Args:
            axis: The axis this `SizeReference` is the size of.
            ref_axis: The reference axis to compute the size from.
            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
                and no fixed **ref_size** is given,
                **n** is used to compute the size of the parameterized **ref_axis**.
            ref_size: Overwrite the reference size instead of deriving it from
                **ref_axis**
                (**ref_axis.scale** is still used; any given **n** is ignored).
        """
        assert axis.size == self, (
            "Given `axis.size` is not defined by this `SizeReference`"
        )

        assert ref_axis.id == self.axis_id, (
            f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
        )

        assert axis.unit == ref_axis.unit, (
            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
            f" but {axis.unit}!={ref_axis.unit}"
        )
        if ref_size is None:
            if isinstance(ref_axis.size, (int, float)):
                ref_size = ref_axis.size
            elif isinstance(ref_axis.size, ParameterizedSize):
                ref_size = ref_axis.size.get_size(n)
            elif isinstance(ref_axis.size, DataDependentSize):
                raise ValueError(
                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
                )
            elif isinstance(ref_axis.size, SizeReference):
                raise ValueError(
                    "Reference axis referenced in `SizeReference` may not be sized by a"
                    + " `SizeReference` itself."
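# --- Editor's illustration (not part of the original module) -----------------
# A minimal, hedged sketch of the size arithmetic documented above; the numbers
# are made up for demonstration.
from bioimageio.spec.model.v0_5 import ParameterizedSize

ps = ParameterizedSize(min=32, step=16)
assert ps.get_size(3) == 80        # 32 + 3 * 16
assert ps.get_n(70) == 3           # smallest n such that min + n*step >= 70
assert ps.validate_size(80) == 80  # 80 = 32 + 3*16 is a valid size

# SizeReference: axis.size = reference.size * reference.scale / axis.scale + offset
assert int(100 * 2 / 4 - 1) == 49  # the docstring example above in plain arithmetic
# ------------------------------------------------------------------------------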
456 ) 457 else: 458 assert_never(ref_axis.size) 459 460 return int(ref_size * ref_axis.scale / axis.scale + self.offset) 461 462 @staticmethod 463 def _get_unit( 464 axis: Union[ 465 ChannelAxis, 466 IndexInputAxis, 467 IndexOutputAxis, 468 TimeInputAxis, 469 SpaceInputAxis, 470 TimeOutputAxis, 471 TimeOutputAxisWithHalo, 472 SpaceOutputAxis, 473 SpaceOutputAxisWithHalo, 474 ], 475 ): 476 return axis.unit 477 478 479class AxisBase(NodeWithExplicitlySetFields): 480 id: AxisId 481 """An axis id unique across all axes of one tensor.""" 482 483 description: Annotated[str, MaxLen(128)] = "" 484 """A short description of this axis beyond its type and id.""" 485 486 487class WithHalo(Node): 488 halo: Annotated[int, Ge(1)] 489 """The halo should be cropped from the output tensor to avoid boundary effects. 490 It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`. 491 To document a halo that is already cropped by the model use `size.offset` instead.""" 492 493 size: Annotated[ 494 SizeReference, 495 Field( 496 examples=[ 497 10, 498 SizeReference( 499 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 500 ).model_dump(mode="json"), 501 ] 502 ), 503 ] 504 """reference to another axis with an optional offset (see `SizeReference`)""" 505 506 507BATCH_AXIS_ID = AxisId("batch") 508 509 510class BatchAxis(AxisBase): 511 implemented_type: ClassVar[Literal["batch"]] = "batch" 512 if TYPE_CHECKING: 513 type: Literal["batch"] = "batch" 514 else: 515 type: Literal["batch"] 516 517 id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID 518 size: Optional[Literal[1]] = None 519 """The batch size may be fixed to 1, 520 otherwise (the default) it may be chosen arbitrarily depending on available memory""" 521 522 @property 523 def scale(self): 524 return 1.0 525 526 @property 527 def concatenable(self): 528 return True 529 530 @property 531 def unit(self): 532 return None 533 534 535class ChannelAxis(AxisBase): 536 implemented_type: ClassVar[Literal["channel"]] = "channel" 537 if TYPE_CHECKING: 538 type: Literal["channel"] = "channel" 539 else: 540 type: Literal["channel"] 541 542 id: NonBatchAxisId = AxisId("channel") 543 544 channel_names: NotEmpty[List[Identifier]] 545 546 @property 547 def size(self) -> int: 548 return len(self.channel_names) 549 550 @property 551 def concatenable(self): 552 return False 553 554 @property 555 def scale(self) -> float: 556 return 1.0 557 558 @property 559 def unit(self): 560 return None 561 562 563class IndexAxisBase(AxisBase): 564 implemented_type: ClassVar[Literal["index"]] = "index" 565 if TYPE_CHECKING: 566 type: Literal["index"] = "index" 567 else: 568 type: Literal["index"] 569 570 id: NonBatchAxisId = AxisId("index") 571 572 @property 573 def scale(self) -> float: 574 return 1.0 575 576 @property 577 def unit(self): 578 return None 579 580 581class _WithInputAxisSize(Node): 582 size: Annotated[ 583 Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference], 584 Field( 585 examples=[ 586 10, 587 ParameterizedSize(min=32, step=16).model_dump(mode="json"), 588 SizeReference( 589 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 590 ).model_dump(mode="json"), 591 ] 592 ), 593 ] 594 """The size/length of this axis can be specified as 595 - fixed integer 596 - parameterized series of valid sizes (`ParameterizedSize`) 597 - reference to another axis with an optional offset (`SizeReference`) 598 """ 599 600 601class IndexInputAxis(IndexAxisBase, _WithInputAxisSize): 602 concatenable: bool = False 603 """If a model has a `concatenable` 
input axis, it can be processed blockwise, 604 splitting a longer sample axis into blocks matching its input tensor description. 605 Output axes are concatenable if they have a `SizeReference` to a concatenable 606 input axis. 607 """ 608 609 610class IndexOutputAxis(IndexAxisBase): 611 size: Annotated[ 612 Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize], 613 Field( 614 examples=[ 615 10, 616 SizeReference( 617 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 618 ).model_dump(mode="json"), 619 ] 620 ), 621 ] 622 """The size/length of this axis can be specified as 623 - fixed integer 624 - reference to another axis with an optional offset (`SizeReference`) 625 - data dependent size using `DataDependentSize` (size is only known after model inference) 626 """ 627 628 629class TimeAxisBase(AxisBase): 630 implemented_type: ClassVar[Literal["time"]] = "time" 631 if TYPE_CHECKING: 632 type: Literal["time"] = "time" 633 else: 634 type: Literal["time"] 635 636 id: NonBatchAxisId = AxisId("time") 637 unit: Optional[TimeUnit] = None 638 scale: Annotated[float, Gt(0)] = 1.0 639 640 641class TimeInputAxis(TimeAxisBase, _WithInputAxisSize): 642 concatenable: bool = False 643 """If a model has a `concatenable` input axis, it can be processed blockwise, 644 splitting a longer sample axis into blocks matching its input tensor description. 645 Output axes are concatenable if they have a `SizeReference` to a concatenable 646 input axis. 647 """ 648 649 650class SpaceAxisBase(AxisBase): 651 implemented_type: ClassVar[Literal["space"]] = "space" 652 if TYPE_CHECKING: 653 type: Literal["space"] = "space" 654 else: 655 type: Literal["space"] 656 657 id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x") 658 unit: Optional[SpaceUnit] = None 659 scale: Annotated[float, Gt(0)] = 1.0 660 661 662class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize): 663 concatenable: bool = False 664 """If a model has a `concatenable` input axis, it can be processed blockwise, 665 splitting a longer sample axis into blocks matching its input tensor description. 666 Output axes are concatenable if they have a `SizeReference` to a concatenable 667 input axis. 
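# --- Editor's illustration (not part of the original module) -----------------
# Putting the input axis classes above together: a typical 2D image input with
# fixed batch/channel axes and parameterized spatial extent (values are
# illustrative only; constructor patterns follow the docstring examples above).
from bioimageio.spec.model.v0_5 import (
    AxisId,
    BatchAxis,
    ChannelAxis,
    Identifier,
    ParameterizedSize,
    SpaceInputAxis,
)

axes = [
    BatchAxis(),  # batch size left open (or fixed to 1)
    ChannelAxis(channel_names=[Identifier("raw")]),  # size == len(channel_names)
    SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=32, step=16)),
    SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=32, step=16)),
]
assert axes[1].size == 1
# ------------------------------------------------------------------------------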
668 """ 669 670 671INPUT_AXIS_TYPES = ( 672 BatchAxis, 673 ChannelAxis, 674 IndexInputAxis, 675 TimeInputAxis, 676 SpaceInputAxis, 677) 678"""intended for isinstance comparisons in py<3.10""" 679 680_InputAxisUnion = Union[ 681 BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis 682] 683InputAxis = Annotated[_InputAxisUnion, Discriminator("type")] 684 685 686class _WithOutputAxisSize(Node): 687 size: Annotated[ 688 Union[Annotated[int, Gt(0)], SizeReference], 689 Field( 690 examples=[ 691 10, 692 SizeReference( 693 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 694 ).model_dump(mode="json"), 695 ] 696 ), 697 ] 698 """The size/length of this axis can be specified as 699 - fixed integer 700 - reference to another axis with an optional offset (see `SizeReference`) 701 """ 702 703 704class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize): 705 pass 706 707 708class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo): 709 pass 710 711 712def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]: 713 if isinstance(v, dict): 714 return "with_halo" if "halo" in v else "wo_halo" 715 else: 716 return "with_halo" if hasattr(v, "halo") else "wo_halo" 717 718 719_TimeOutputAxisUnion = Annotated[ 720 Union[ 721 Annotated[TimeOutputAxis, Tag("wo_halo")], 722 Annotated[TimeOutputAxisWithHalo, Tag("with_halo")], 723 ], 724 Discriminator(_get_halo_axis_discriminator_value), 725] 726 727 728class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize): 729 pass 730 731 732class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo): 733 pass 734 735 736_SpaceOutputAxisUnion = Annotated[ 737 Union[ 738 Annotated[SpaceOutputAxis, Tag("wo_halo")], 739 Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")], 740 ], 741 Discriminator(_get_halo_axis_discriminator_value), 742] 743 744 745_OutputAxisUnion = Union[ 746 BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion 747] 748OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")] 749 750OUTPUT_AXIS_TYPES = ( 751 BatchAxis, 752 ChannelAxis, 753 IndexOutputAxis, 754 TimeOutputAxis, 755 TimeOutputAxisWithHalo, 756 SpaceOutputAxis, 757 SpaceOutputAxisWithHalo, 758) 759"""intended for isinstance comparisons in py<3.10""" 760 761 762AnyAxis = Union[InputAxis, OutputAxis] 763 764ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES 765"""intended for isinstance comparisons in py<3.10""" 766 767TVs = Union[ 768 NotEmpty[List[int]], 769 NotEmpty[List[float]], 770 NotEmpty[List[bool]], 771 NotEmpty[List[str]], 772] 773 774 775NominalOrOrdinalDType = Literal[ 776 "float32", 777 "float64", 778 "uint8", 779 "int8", 780 "uint16", 781 "int16", 782 "uint32", 783 "int32", 784 "uint64", 785 "int64", 786 "bool", 787] 788 789 790class NominalOrOrdinalDataDescr(Node): 791 values: TVs 792 """A fixed set of nominal or an ascending sequence of ordinal values. 793 In this case `data.type` is required to be an unsigend integer type, e.g. 'uint8'. 794 String `values` are interpreted as labels for tensor values 0, ..., N. 795 Note: as YAML 1.2 does not natively support a "set" datatype, 796 nominal values should be given as a sequence (aka list/array) as well. 
797 """ 798 799 type: Annotated[ 800 NominalOrOrdinalDType, 801 Field( 802 examples=[ 803 "float32", 804 "uint8", 805 "uint16", 806 "int64", 807 "bool", 808 ], 809 ), 810 ] = "uint8" 811 812 @model_validator(mode="after") 813 def _validate_values_match_type( 814 self, 815 ) -> Self: 816 incompatible: List[Any] = [] 817 for v in self.values: 818 if self.type == "bool": 819 if not isinstance(v, bool): 820 incompatible.append(v) 821 elif self.type in DTYPE_LIMITS: 822 if ( 823 isinstance(v, (int, float)) 824 and ( 825 v < DTYPE_LIMITS[self.type].min 826 or v > DTYPE_LIMITS[self.type].max 827 ) 828 or (isinstance(v, str) and "uint" not in self.type) 829 or (isinstance(v, float) and "int" in self.type) 830 ): 831 incompatible.append(v) 832 else: 833 incompatible.append(v) 834 835 if len(incompatible) == 5: 836 incompatible.append("...") 837 break 838 839 if incompatible: 840 raise ValueError( 841 f"data type '{self.type}' incompatible with values {incompatible}" 842 ) 843 844 return self 845 846 unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None 847 848 @property 849 def range(self): 850 if isinstance(self.values[0], str): 851 return 0, len(self.values) - 1 852 else: 853 return min(self.values), max(self.values) 854 855 856IntervalOrRatioDType = Literal[ 857 "float32", 858 "float64", 859 "uint8", 860 "int8", 861 "uint16", 862 "int16", 863 "uint32", 864 "int32", 865 "uint64", 866 "int64", 867] 868 869 870class IntervalOrRatioDataDescr(Node): 871 type: Annotated[ # TODO: rename to dtype 872 IntervalOrRatioDType, 873 Field( 874 examples=["float32", "float64", "uint8", "uint16"], 875 ), 876 ] = "float32" 877 range: Tuple[Optional[float], Optional[float]] = ( 878 None, 879 None, 880 ) 881 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor. 882 `None` corresponds to min/max of what can be expressed by **type**.""" 883 unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit" 884 scale: float = 1.0 885 """Scale for data on an interval (or ratio) scale.""" 886 offset: Optional[float] = None 887 """Offset for data on a ratio scale.""" 888 889 @model_validator(mode="before") 890 def _replace_inf(cls, data: Any): 891 if is_dict(data): 892 if "range" in data and is_sequence(data["range"]): 893 forbidden = ( 894 "inf", 895 "-inf", 896 ".inf", 897 "-.inf", 898 float("inf"), 899 float("-inf"), 900 ) 901 if any(v in forbidden for v in data["range"]): 902 issue_warning("replaced 'inf' value", value=data["range"]) 903 904 data["range"] = tuple( 905 (None if v in forbidden else v) for v in data["range"] 906 ) 907 908 return data 909 910 911TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr] 912 913 914class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC): 915 """processing base class""" 916 917 918class BinarizeKwargs(ProcessingKwargs): 919 """key word arguments for `BinarizeDescr`""" 920 921 threshold: float 922 """The fixed threshold""" 923 924 925class BinarizeAlongAxisKwargs(ProcessingKwargs): 926 """key word arguments for `BinarizeDescr`""" 927 928 threshold: NotEmpty[List[float]] 929 """The fixed threshold values along `axis`""" 930 931 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 932 """The `threshold` axis""" 933 934 935class BinarizeDescr(ProcessingDescrBase): 936 """Binarize the tensor with a fixed threshold. 937 938 Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold` 939 will be set to one, values below the threshold to zero. 
940 941 Examples: 942 - in YAML 943 ```yaml 944 postprocessing: 945 - id: binarize 946 kwargs: 947 axis: 'channel' 948 threshold: [0.25, 0.5, 0.75] 949 ``` 950 - in Python: 951 >>> postprocessing = [BinarizeDescr( 952 ... kwargs=BinarizeAlongAxisKwargs( 953 ... axis=AxisId('channel'), 954 ... threshold=[0.25, 0.5, 0.75], 955 ... ) 956 ... )] 957 """ 958 959 implemented_id: ClassVar[Literal["binarize"]] = "binarize" 960 if TYPE_CHECKING: 961 id: Literal["binarize"] = "binarize" 962 else: 963 id: Literal["binarize"] 964 kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs] 965 966 967class ClipDescr(ProcessingDescrBase): 968 """Set tensor values below min to min and above max to max. 969 970 See `ScaleRangeDescr` for examples. 971 """ 972 973 implemented_id: ClassVar[Literal["clip"]] = "clip" 974 if TYPE_CHECKING: 975 id: Literal["clip"] = "clip" 976 else: 977 id: Literal["clip"] 978 979 kwargs: ClipKwargs 980 981 982class EnsureDtypeKwargs(ProcessingKwargs): 983 """key word arguments for `EnsureDtypeDescr`""" 984 985 dtype: Literal[ 986 "float32", 987 "float64", 988 "uint8", 989 "int8", 990 "uint16", 991 "int16", 992 "uint32", 993 "int32", 994 "uint64", 995 "int64", 996 "bool", 997 ] 998 999 1000class EnsureDtypeDescr(ProcessingDescrBase): 1001 """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching). 1002 1003 This can for example be used to ensure the inner neural network model gets a 1004 different input tensor data type than the fully described bioimage.io model does. 1005 1006 Examples: 1007 The described bioimage.io model (incl. preprocessing) accepts any 1008 float32-compatible tensor, normalizes it with percentiles and clipping and then 1009 casts it to uint8, which is what the neural network in this example expects. 1010 - in YAML 1011 ```yaml 1012 inputs: 1013 - data: 1014 type: float32 # described bioimage.io model is compatible with any float32 input tensor 1015 preprocessing: 1016 - id: scale_range 1017 kwargs: 1018 axes: ['y', 'x'] 1019 max_percentile: 99.8 1020 min_percentile: 5.0 1021 - id: clip 1022 kwargs: 1023 min: 0.0 1024 max: 1.0 1025 - id: ensure_dtype # the neural network of the model requires uint8 1026 kwargs: 1027 dtype: uint8 1028 ``` 1029 - in Python: 1030 >>> preprocessing = [ 1031 ... ScaleRangeDescr( 1032 ... kwargs=ScaleRangeKwargs( 1033 ... axes= (AxisId('y'), AxisId('x')), 1034 ... max_percentile= 99.8, 1035 ... min_percentile= 5.0, 1036 ... ) 1037 ... ), 1038 ... ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)), 1039 ... EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")), 1040 ... ] 1041 """ 1042 1043 implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype" 1044 if TYPE_CHECKING: 1045 id: Literal["ensure_dtype"] = "ensure_dtype" 1046 else: 1047 id: Literal["ensure_dtype"] 1048 1049 kwargs: EnsureDtypeKwargs 1050 1051 1052class ScaleLinearKwargs(ProcessingKwargs): 1053 """Key word arguments for `ScaleLinearDescr`""" 1054 1055 gain: float = 1.0 1056 """multiplicative factor""" 1057 1058 offset: float = 0.0 1059 """additive term""" 1060 1061 @model_validator(mode="after") 1062 def _validate(self) -> Self: 1063 if self.gain == 1.0 and self.offset == 0.0: 1064 raise ValueError( 1065 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 1066 + " != 0.0." 
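# --- Editor's illustration (not part of the original module) -----------------
# What a `scale_linear` step computes, sketched with numpy; this mirrors the
# documented kwargs (gain: multiplicative factor, offset: additive term) and is
# not the library's implementation.
import numpy as np

tensor = np.array([[1.0, 2.0], [3.0, 4.0]], dtype="float32")
gain, offset = 2.0, 3.0
out = gain * tensor + offset  # applied element-wise for scalar gain/offset
# ------------------------------------------------------------------------------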
1067 ) 1068 1069 return self 1070 1071 1072class ScaleLinearAlongAxisKwargs(ProcessingKwargs): 1073 """Key word arguments for `ScaleLinearDescr`""" 1074 1075 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 1076 """The axis of gain and offset values.""" 1077 1078 gain: Union[float, NotEmpty[List[float]]] = 1.0 1079 """multiplicative factor""" 1080 1081 offset: Union[float, NotEmpty[List[float]]] = 0.0 1082 """additive term""" 1083 1084 @model_validator(mode="after") 1085 def _validate(self) -> Self: 1086 if isinstance(self.gain, list): 1087 if isinstance(self.offset, list): 1088 if len(self.gain) != len(self.offset): 1089 raise ValueError( 1090 f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match." 1091 ) 1092 else: 1093 self.offset = [float(self.offset)] * len(self.gain) 1094 elif isinstance(self.offset, list): 1095 self.gain = [float(self.gain)] * len(self.offset) 1096 else: 1097 raise ValueError( 1098 "Do not specify an `axis` for scalar gain and offset values." 1099 ) 1100 1101 if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset): 1102 raise ValueError( 1103 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 1104 + " != 0.0." 1105 ) 1106 1107 return self 1108 1109 1110class ScaleLinearDescr(ProcessingDescrBase): 1111 """Fixed linear scaling. 1112 1113 Examples: 1114 1. Scale with scalar gain and offset 1115 - in YAML 1116 ```yaml 1117 preprocessing: 1118 - id: scale_linear 1119 kwargs: 1120 gain: 2.0 1121 offset: 3.0 1122 ``` 1123 - in Python: 1124 >>> preprocessing = [ 1125 ... ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0)) 1126 ... ] 1127 1128 2. Independent scaling along an axis 1129 - in YAML 1130 ```yaml 1131 preprocessing: 1132 - id: scale_linear 1133 kwargs: 1134 axis: 'channel' 1135 gain: [1.0, 2.0, 3.0] 1136 ``` 1137 - in Python: 1138 >>> preprocessing = [ 1139 ... ScaleLinearDescr( 1140 ... kwargs=ScaleLinearAlongAxisKwargs( 1141 ... axis=AxisId("channel"), 1142 ... gain=[1.0, 2.0, 3.0], 1143 ... ) 1144 ... ) 1145 ... ] 1146 1147 """ 1148 1149 implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear" 1150 if TYPE_CHECKING: 1151 id: Literal["scale_linear"] = "scale_linear" 1152 else: 1153 id: Literal["scale_linear"] 1154 kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs] 1155 1156 1157class SigmoidDescr(ProcessingDescrBase): 1158 """The logistic sigmoid function, a.k.a. expit function. 1159 1160 Examples: 1161 - in YAML 1162 ```yaml 1163 postprocessing: 1164 - id: sigmoid 1165 ``` 1166 - in Python: 1167 >>> postprocessing = [SigmoidDescr()] 1168 """ 1169 1170 implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid" 1171 if TYPE_CHECKING: 1172 id: Literal["sigmoid"] = "sigmoid" 1173 else: 1174 id: Literal["sigmoid"] 1175 1176 @property 1177 def kwargs(self) -> ProcessingKwargs: 1178 """empty kwargs""" 1179 return ProcessingKwargs() 1180 1181 1182class SoftmaxKwargs(ProcessingKwargs): 1183 """key word arguments for `SoftmaxDescr`""" 1184 1185 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] = AxisId("channel") 1186 """The axis to apply the softmax function along. 1187 Note: 1188 Defaults to 'channel' axis 1189 (which may not exist, in which case 1190 a different axis id has to be specified). 1191 """ 1192 1193 1194class SoftmaxDescr(ProcessingDescrBase): 1195 """The softmax function. 
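# --- Editor's illustration (not part of the original module) -----------------
# The softmax operation named by this postprocessing step, sketched with numpy
# along a channel axis; an illustration only, not the library's implementation.
import numpy as np

def _softmax(x: "np.ndarray", axis: int) -> "np.ndarray":
    e = np.exp(x - x.max(axis=axis, keepdims=True))  # subtract max for stability
    return e / e.sum(axis=axis, keepdims=True)

probs = _softmax(np.array([[1.0, 2.0, 3.0]]), axis=1)
assert np.allclose(probs.sum(axis=1), 1.0)  # probabilities sum to 1 along the axis
# ------------------------------------------------------------------------------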
1196 1197 Examples: 1198 - in YAML 1199 ```yaml 1200 postprocessing: 1201 - id: softmax 1202 kwargs: 1203 axis: channel 1204 ``` 1205 - in Python: 1206 >>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))] 1207 """ 1208 1209 implemented_id: ClassVar[Literal["softmax"]] = "softmax" 1210 if TYPE_CHECKING: 1211 id: Literal["softmax"] = "softmax" 1212 else: 1213 id: Literal["softmax"] 1214 1215 kwargs: SoftmaxKwargs = Field(default_factory=SoftmaxKwargs.model_construct) 1216 1217 1218class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs): 1219 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 1220 1221 mean: float 1222 """The mean value to normalize with.""" 1223 1224 std: Annotated[float, Ge(1e-6)] 1225 """The standard deviation value to normalize with.""" 1226 1227 1228class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs): 1229 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 1230 1231 mean: NotEmpty[List[float]] 1232 """The mean value(s) to normalize with.""" 1233 1234 std: NotEmpty[List[Annotated[float, Ge(1e-6)]]] 1235 """The standard deviation value(s) to normalize with. 1236 Size must match `mean` values.""" 1237 1238 axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])] 1239 """The axis of the mean/std values to normalize each entry along that dimension 1240 separately.""" 1241 1242 @model_validator(mode="after") 1243 def _mean_and_std_match(self) -> Self: 1244 if len(self.mean) != len(self.std): 1245 raise ValueError( 1246 f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})" 1247 + " must match." 1248 ) 1249 1250 return self 1251 1252 1253class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase): 1254 """Subtract a given mean and divide by the standard deviation. 1255 1256 Normalize with fixed, precomputed values for 1257 `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std` 1258 Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given 1259 axes. 1260 1261 Examples: 1262 1. scalar value for whole tensor 1263 - in YAML 1264 ```yaml 1265 preprocessing: 1266 - id: fixed_zero_mean_unit_variance 1267 kwargs: 1268 mean: 103.5 1269 std: 13.7 1270 ``` 1271 - in Python 1272 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 1273 ... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7) 1274 ... )] 1275 1276 2. independently along an axis 1277 - in YAML 1278 ```yaml 1279 preprocessing: 1280 - id: fixed_zero_mean_unit_variance 1281 kwargs: 1282 axis: channel 1283 mean: [101.5, 102.5, 103.5] 1284 std: [11.7, 12.7, 13.7] 1285 ``` 1286 - in Python 1287 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 1288 ... kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 1289 ... axis=AxisId("channel"), 1290 ... mean=[101.5, 102.5, 103.5], 1291 ... std=[11.7, 12.7, 13.7], 1292 ... ) 1293 ... 
)]
    """

    implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = (
        "fixed_zero_mean_unit_variance"
    )
    if TYPE_CHECKING:
        id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
    else:
        id: Literal["fixed_zero_mean_unit_variance"]

    kwargs: Union[
        FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
    ]


class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
    """key word arguments for `ZeroMeanUnitVarianceDescr`"""

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize each sample independently leave out the 'batch' axis.
    Default: Scale all axes jointly."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""


class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
    """Subtract the mean and divide by the standard deviation.

    Examples:
        Subtract the tensor mean and divide by its standard deviation
        - in YAML
        ```yaml
        preprocessing:
          - id: zero_mean_unit_variance
        ```
        - in Python
        >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
    """

    implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = (
        "zero_mean_unit_variance"
    )
    if TYPE_CHECKING:
        id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
    else:
        id: Literal["zero_mean_unit_variance"]

    kwargs: ZeroMeanUnitVarianceKwargs = Field(
        default_factory=ZeroMeanUnitVarianceKwargs.model_construct
    )


class ScaleRangeKwargs(ProcessingKwargs):
    """key word arguments for `ScaleRangeDescr`

    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
    this processing step normalizes data to the [0, 1] interval.
    For other percentiles the normalized values will partially be outside the [0, 1]
    interval. Use `ScaleRangeDescr` followed by `ClipDescr` if you want to limit the
    normalized values to a range.
    """

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize samples independently, leave out the "batch" axis.
    Default: Scale all axes jointly."""

    min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
    """The lower percentile used to determine the value to align with zero."""

    max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
    """The upper percentile used to determine the value to align with one.
    Has to be bigger than `min_percentile`.
    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
    accepting percentiles specified in the range 0.0 to 1.0."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability.
    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
    with `v_lower,v_upper` values at the respective percentiles."""

    reference_tensor: Optional[TensorId] = None
    """Tensor ID to compute the percentiles from. Default: The tensor itself.
    For any tensor in `inputs` only input tensor references are allowed."""

    @field_validator("max_percentile", mode="after")
    @classmethod
    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
        if (min_p := info.data["min_percentile"]) >= value:
            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")

        return value


class ScaleRangeDescr(ProcessingDescrBase):
    """Scale with percentiles.

    Examples:
    1. Scale linearly to map the 5th percentile to 0 and the 99.8th percentile to 1.0
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_range
            kwargs:
              axes: ['y', 'x']
              max_percentile: 99.8
              min_percentile: 5.0
        ```
        - in Python
        >>> preprocessing = [
        ...     ScaleRangeDescr(
        ...         kwargs=ScaleRangeKwargs(
        ...             axes=(AxisId('y'), AxisId('x')),
        ...             max_percentile=99.8,
        ...             min_percentile=5.0,
        ...         )
        ...     ),
        ... ]

    2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_range
            kwargs:
              axes: ['y', 'x']
              max_percentile: 99.8
              min_percentile: 5.0
          - id: clip
            kwargs:
              min: 0.0
              max: 1.0
        ```
        - in Python
        >>> preprocessing = [
        ...     ScaleRangeDescr(
        ...         kwargs=ScaleRangeKwargs(
        ...             axes=(AxisId('y'), AxisId('x')),
        ...             max_percentile=99.8,
        ...             min_percentile=5.0,
        ...         )
        ...     ),
        ...     ClipDescr(
        ...         kwargs=ClipKwargs(
        ...             min=0.0,
        ...             max=1.0,
        ...         )
        ...     ),
        ... ]

    """

    implemented_id: ClassVar[Literal["scale_range"]] = "scale_range"
    if TYPE_CHECKING:
        id: Literal["scale_range"] = "scale_range"
    else:
        id: Literal["scale_range"]
    kwargs: ScaleRangeKwargs = Field(default_factory=ScaleRangeKwargs.model_construct)


class ScaleMeanVarianceKwargs(ProcessingKwargs):
    """key word arguments for `ScaleMeanVarianceDescr`"""

    reference_tensor: TensorId
    """Name of tensor to match."""

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize samples independently, leave out the 'batch' axis.
    Default: Scale all axes jointly."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability:
    `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""


class ScaleMeanVarianceDescr(ProcessingDescrBase):
    """Scale a tensor's data distribution to match another tensor's mean/std.
    `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
    """

    implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
    if TYPE_CHECKING:
        id: Literal["scale_mean_variance"] = "scale_mean_variance"
    else:
        id: Literal["scale_mean_variance"]
    kwargs: ScaleMeanVarianceKwargs


PreprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        EnsureDtypeDescr,
        FixedZeroMeanUnitVarianceDescr,
        ScaleLinearDescr,
        ScaleRangeDescr,
        SigmoidDescr,
        SoftmaxDescr,
        ZeroMeanUnitVarianceDescr,
    ],
    Discriminator("id"),
]
PostprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        EnsureDtypeDescr,
        FixedZeroMeanUnitVarianceDescr,
        ScaleLinearDescr,
        ScaleMeanVarianceDescr,
        ScaleRangeDescr,
        SigmoidDescr,
        SoftmaxDescr,
        ZeroMeanUnitVarianceDescr,
    ],
    Discriminator("id"),
]

IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis)


class TensorDescrBase(Node, Generic[IO_AxisT]):
    id: TensorId
    """Tensor id. No duplicates are allowed."""

    description: Annotated[str, MaxLen(128)] = ""
    """free text description"""

    axes: NotEmpty[Sequence[IO_AxisT]]
    """tensor axes"""

    @property
    def shape(self):
        return tuple(a.size for a in self.axes)

    @field_validator("axes", mode="after", check_fields=False)
    @classmethod
    def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
        batch_axes = [a for a in axes if a.type == "batch"]
        if len(batch_axes) > 1:
            raise ValueError(
                f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
            )

        seen_ids: Set[AxisId] = set()
        duplicate_axes_ids: Set[AxisId] = set()
        for a in axes:
            (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)

        if duplicate_axes_ids:
            raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")

        return axes

    test_tensor: FAIR[Optional[FileDescr_]] = None
    """An example tensor to use for testing.
    Using the model with the test input tensors is expected to yield the test output tensors.
    Each test tensor has to be an ndarray in the
    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
    The file extension must be '.npy'."""

    sample_tensor: FAIR[Optional[FileDescr_]] = None
    """A sample tensor to illustrate a possible input/output for the model.
    The sample image primarily serves to inform a human user about an example use case
    and is typically stored as .hdf5, .png or .tiff.
    It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
    (numpy's `.npy` format is not supported).
    The image dimensionality has to match the number of axes specified in this tensor description.
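# --- Editor's illustration (not part of the original module) -----------------
# Test tensors have to be single arrays in the numpy.lib `.npy` format, as
# documented above. A hedged sketch of producing one (file name and shape are
# illustrative):
import numpy as np

test_input = np.random.rand(1, 3, 64, 64).astype("float32")  # batch, channel, y, x
np.save("test_input.npy", test_input)  # then referenced via `test_tensor` (`FileDescr(source=...)`)
# ------------------------------------------------------------------------------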
1576 """ 1577 1578 @model_validator(mode="after") 1579 def _validate_sample_tensor(self) -> Self: 1580 if self.sample_tensor is None or not get_validation_context().perform_io_checks: 1581 return self 1582 1583 reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256) 1584 tensor: NDArray[Any] = imread( # pyright: ignore[reportUnknownVariableType] 1585 reader.read(), 1586 extension=PurePosixPath(reader.original_file_name).suffix, 1587 ) 1588 n_dims = len(tensor.squeeze().shape) 1589 n_dims_min = n_dims_max = len(self.axes) 1590 1591 for a in self.axes: 1592 if isinstance(a, BatchAxis): 1593 n_dims_min -= 1 1594 elif isinstance(a.size, int): 1595 if a.size == 1: 1596 n_dims_min -= 1 1597 elif isinstance(a.size, (ParameterizedSize, DataDependentSize)): 1598 if a.size.min == 1: 1599 n_dims_min -= 1 1600 elif isinstance(a.size, SizeReference): 1601 if a.size.offset < 2: 1602 # size reference may result in singleton axis 1603 n_dims_min -= 1 1604 else: 1605 assert_never(a.size) 1606 1607 n_dims_min = max(0, n_dims_min) 1608 if n_dims < n_dims_min or n_dims > n_dims_max: 1609 raise ValueError( 1610 f"Expected sample tensor to have {n_dims_min} to" 1611 + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})." 1612 ) 1613 1614 return self 1615 1616 data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = ( 1617 IntervalOrRatioDataDescr() 1618 ) 1619 """Description of the tensor's data values, optionally per channel. 1620 If specified per channel, the data `type` needs to match across channels.""" 1621 1622 @property 1623 def dtype( 1624 self, 1625 ) -> Literal[ 1626 "float32", 1627 "float64", 1628 "uint8", 1629 "int8", 1630 "uint16", 1631 "int16", 1632 "uint32", 1633 "int32", 1634 "uint64", 1635 "int64", 1636 "bool", 1637 ]: 1638 """dtype as specified under `data.type` or `data[i].type`""" 1639 if isinstance(self.data, collections.abc.Sequence): 1640 return self.data[0].type 1641 else: 1642 return self.data.type 1643 1644 @field_validator("data", mode="after") 1645 @classmethod 1646 def _check_data_type_across_channels( 1647 cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] 1648 ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]: 1649 if not isinstance(value, list): 1650 return value 1651 1652 dtypes = {t.type for t in value} 1653 if len(dtypes) > 1: 1654 raise ValueError( 1655 "Tensor data descriptions per channel need to agree in their data" 1656 + f" `type`, but found {dtypes}." 1657 ) 1658 1659 return value 1660 1661 @model_validator(mode="after") 1662 def _check_data_matches_channelaxis(self) -> Self: 1663 if not isinstance(self.data, (list, tuple)): 1664 return self 1665 1666 for a in self.axes: 1667 if isinstance(a, ChannelAxis): 1668 size = a.size 1669 assert isinstance(size, int) 1670 break 1671 else: 1672 return self 1673 1674 if len(self.data) != size: 1675 raise ValueError( 1676 f"Got tensor data descriptions for {len(self.data)} channels, but" 1677 + f" '{a.id}' axis has size {size}." 1678 ) 1679 1680 return self 1681 1682 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]: 1683 if len(array.shape) != len(self.axes): 1684 raise ValueError( 1685 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})" 1686 + f" incompatible with {len(self.axes)} axes." 1687 ) 1688 return {a.id: array.shape[i] for i, a in enumerate(self.axes)} 1689 1690 1691class InputTensorDescr(TensorDescrBase[InputAxis]): 1692 id: TensorId = TensorId("input") 1693 """Input tensor id. 
1694 No duplicates are allowed across all inputs and outputs.""" 1695 1696 optional: bool = False 1697 """indicates that this tensor may be `None`""" 1698 1699 preprocessing: List[PreprocessingDescr] = Field( 1700 default_factory=cast(Callable[[], List[PreprocessingDescr]], list) 1701 ) 1702 1703 """Description of how this input should be preprocessed. 1704 1705 notes: 1706 - If preprocessing does not start with an 'ensure_dtype' entry, it is added 1707 to ensure an input tensor's data type matches the input tensor's data description. 1708 - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 1709 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally 1710 changing the data type. 1711 """ 1712 1713 @model_validator(mode="after") 1714 def _validate_preprocessing_kwargs(self) -> Self: 1715 axes_ids = [a.id for a in self.axes] 1716 for p in self.preprocessing: 1717 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 1718 if kwargs_axes is None: 1719 continue 1720 1721 if not isinstance(kwargs_axes, collections.abc.Sequence): 1722 raise ValueError( 1723 f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}" 1724 ) 1725 1726 if any(a not in axes_ids for a in kwargs_axes): 1727 raise ValueError( 1728 "`preprocessing.i.kwargs.axes` needs to be subset of axes ids" 1729 ) 1730 1731 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 1732 dtype = self.data.type 1733 else: 1734 dtype = self.data[0].type 1735 1736 # ensure `preprocessing` begins with `EnsureDtypeDescr` 1737 if not self.preprocessing or not isinstance( 1738 self.preprocessing[0], EnsureDtypeDescr 1739 ): 1740 self.preprocessing.insert( 1741 0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1742 ) 1743 1744 # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 1745 if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)): 1746 self.preprocessing.append( 1747 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1748 ) 1749 1750 return self 1751 1752 1753def convert_axes( 1754 axes: str, 1755 *, 1756 shape: Union[ 1757 Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4 1758 ], 1759 tensor_type: Literal["input", "output"], 1760 halo: Optional[Sequence[int]], 1761 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 1762): 1763 ret: List[AnyAxis] = [] 1764 for i, a in enumerate(axes): 1765 axis_type = _AXIS_TYPE_MAP.get(a, a) 1766 if axis_type == "batch": 1767 ret.append(BatchAxis()) 1768 continue 1769 1770 scale = 1.0 1771 if isinstance(shape, _ParameterizedInputShape_v0_4): 1772 if shape.step[i] == 0: 1773 size = shape.min[i] 1774 else: 1775 size = ParameterizedSize(min=shape.min[i], step=shape.step[i]) 1776 elif isinstance(shape, _ImplicitOutputShape_v0_4): 1777 ref_t = str(shape.reference_tensor) 1778 if ref_t.count(".") == 1: 1779 t_id, orig_a_id = ref_t.split(".") 1780 else: 1781 t_id = ref_t 1782 orig_a_id = a 1783 1784 a_id = _AXIS_ID_MAP.get(orig_a_id, a) 1785 if not (orig_scale := shape.scale[i]): 1786 # old way to insert a new axis dimension 1787 size = int(2 * shape.offset[i]) 1788 else: 1789 scale = 1 / orig_scale 1790 if axis_type in ("channel", "index"): 1791 # these axes no longer have a scale 1792 offset_from_scale = orig_scale * size_refs.get( 1793 _TensorName_v0_4(t_id), {} 1794 ).get(orig_a_id, 0) 1795 else: 1796 offset_from_scale = 0 1797 size = SizeReference( 1798 tensor_id=TensorId(t_id), 1799 axis_id=AxisId(a_id), 1800 
offset=int(offset_from_scale + 2 * shape.offset[i]), 1801 ) 1802 else: 1803 size = shape[i] 1804 1805 if axis_type == "time": 1806 if tensor_type == "input": 1807 ret.append(TimeInputAxis(size=size, scale=scale)) 1808 else: 1809 assert not isinstance(size, ParameterizedSize) 1810 if halo is None: 1811 ret.append(TimeOutputAxis(size=size, scale=scale)) 1812 else: 1813 assert not isinstance(size, int) 1814 ret.append( 1815 TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i]) 1816 ) 1817 1818 elif axis_type == "index": 1819 if tensor_type == "input": 1820 ret.append(IndexInputAxis(size=size)) 1821 else: 1822 if isinstance(size, ParameterizedSize): 1823 size = DataDependentSize(min=size.min) 1824 1825 ret.append(IndexOutputAxis(size=size)) 1826 elif axis_type == "channel": 1827 assert not isinstance(size, ParameterizedSize) 1828 if isinstance(size, SizeReference): 1829 warnings.warn( 1830 "Conversion of channel size from an implicit output shape may be" 1831 + " wrong" 1832 ) 1833 ret.append( 1834 ChannelAxis( 1835 channel_names=[ 1836 Identifier(f"channel{i}") for i in range(size.offset) 1837 ] 1838 ) 1839 ) 1840 else: 1841 ret.append( 1842 ChannelAxis( 1843 channel_names=[Identifier(f"channel{i}") for i in range(size)] 1844 ) 1845 ) 1846 elif axis_type == "space": 1847 if tensor_type == "input": 1848 ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale)) 1849 else: 1850 assert not isinstance(size, ParameterizedSize) 1851 if halo is None or halo[i] == 0: 1852 ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale)) 1853 elif isinstance(size, int): 1854 raise NotImplementedError( 1855 f"output axis with halo and fixed size (here {size}) not allowed" 1856 ) 1857 else: 1858 ret.append( 1859 SpaceOutputAxisWithHalo( 1860 id=AxisId(a), size=size, scale=scale, halo=halo[i] 1861 ) 1862 ) 1863 1864 return ret 1865 1866 1867def _axes_letters_to_ids( 1868 axes: Optional[str], 1869) -> Optional[List[AxisId]]: 1870 if axes is None: 1871 return None 1872 1873 return [AxisId(a) for a in axes] 1874 1875 1876def _get_complement_v04_axis( 1877 tensor_axes: Sequence[str], axes: Optional[Sequence[str]] 1878) -> Optional[AxisId]: 1879 if axes is None: 1880 return None 1881 1882 non_complement_axes = set(axes) | {"b"} 1883 complement_axes = [a for a in tensor_axes if a not in non_complement_axes] 1884 if len(complement_axes) > 1: 1885 raise ValueError( 1886 f"Expected none or a single complement axis, but axes '{axes}' " 1887 + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'." 
1888 ) 1889 1890 return None if not complement_axes else AxisId(complement_axes[0]) 1891 1892 1893def _convert_proc( 1894 p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4], 1895 tensor_axes: Sequence[str], 1896) -> Union[PreprocessingDescr, PostprocessingDescr]: 1897 if isinstance(p, _BinarizeDescr_v0_4): 1898 return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold)) 1899 elif isinstance(p, _ClipDescr_v0_4): 1900 return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max)) 1901 elif isinstance(p, _SigmoidDescr_v0_4): 1902 return SigmoidDescr() 1903 elif isinstance(p, _ScaleLinearDescr_v0_4): 1904 axes = _axes_letters_to_ids(p.kwargs.axes) 1905 if p.kwargs.axes is None: 1906 axis = None 1907 else: 1908 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes) 1909 1910 if axis is None: 1911 assert not isinstance(p.kwargs.gain, list) 1912 assert not isinstance(p.kwargs.offset, list) 1913 kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset) 1914 else: 1915 kwargs = ScaleLinearAlongAxisKwargs( 1916 axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset 1917 ) 1918 return ScaleLinearDescr(kwargs=kwargs) 1919 elif isinstance(p, _ScaleMeanVarianceDescr_v0_4): 1920 return ScaleMeanVarianceDescr( 1921 kwargs=ScaleMeanVarianceKwargs( 1922 axes=_axes_letters_to_ids(p.kwargs.axes), 1923 reference_tensor=TensorId(str(p.kwargs.reference_tensor)), 1924 eps=p.kwargs.eps, 1925 ) 1926 ) 1927 elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4): 1928 if p.kwargs.mode == "fixed": 1929 mean = p.kwargs.mean 1930 std = p.kwargs.std 1931 assert mean is not None 1932 assert std is not None 1933 1934 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes) 1935 1936 if axis is None: 1937 if isinstance(mean, list): 1938 raise ValueError("Expected single float value for mean, not <list>") 1939 if isinstance(std, list): 1940 raise ValueError("Expected single float value for std, not <list>") 1941 return FixedZeroMeanUnitVarianceDescr( 1942 kwargs=FixedZeroMeanUnitVarianceKwargs.model_construct( 1943 mean=mean, 1944 std=std, 1945 ) 1946 ) 1947 else: 1948 if not isinstance(mean, list): 1949 mean = [float(mean)] 1950 if not isinstance(std, list): 1951 std = [float(std)] 1952 1953 return FixedZeroMeanUnitVarianceDescr( 1954 kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 1955 axis=axis, mean=mean, std=std 1956 ) 1957 ) 1958 1959 else: 1960 axes = _axes_letters_to_ids(p.kwargs.axes) or [] 1961 if p.kwargs.mode == "per_dataset": 1962 axes = [AxisId("batch")] + axes 1963 if not axes: 1964 axes = None 1965 return ZeroMeanUnitVarianceDescr( 1966 kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps) 1967 ) 1968 1969 elif isinstance(p, _ScaleRangeDescr_v0_4): 1970 return ScaleRangeDescr( 1971 kwargs=ScaleRangeKwargs( 1972 axes=_axes_letters_to_ids(p.kwargs.axes), 1973 min_percentile=p.kwargs.min_percentile, 1974 max_percentile=p.kwargs.max_percentile, 1975 eps=p.kwargs.eps, 1976 ) 1977 ) 1978 else: 1979 assert_never(p) 1980 1981 1982class _InputTensorConv( 1983 Converter[ 1984 _InputTensorDescr_v0_4, 1985 InputTensorDescr, 1986 FileSource_, 1987 Optional[FileSource_], 1988 Mapping[_TensorName_v0_4, Mapping[str, int]], 1989 ] 1990): 1991 def _convert( 1992 self, 1993 src: _InputTensorDescr_v0_4, 1994 tgt: "type[InputTensorDescr] | type[dict[str, Any]]", 1995 test_tensor: FileSource_, 1996 sample_tensor: Optional[FileSource_], 1997 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 1998 ) -> "InputTensorDescr | dict[str, Any]": 1999 axes: List[InputAxis] = 
convert_axes( # pyright: ignore[reportAssignmentType] 2000 src.axes, 2001 shape=src.shape, 2002 tensor_type="input", 2003 halo=None, 2004 size_refs=size_refs, 2005 ) 2006 prep: List[PreprocessingDescr] = [] 2007 for p in src.preprocessing: 2008 cp = _convert_proc(p, src.axes) 2009 assert not isinstance(cp, ScaleMeanVarianceDescr) 2010 prep.append(cp) 2011 2012 prep.append(EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32"))) 2013 2014 return tgt( 2015 axes=axes, 2016 id=TensorId(str(src.name)), 2017 test_tensor=FileDescr(source=test_tensor), 2018 sample_tensor=( 2019 None if sample_tensor is None else FileDescr(source=sample_tensor) 2020 ), 2021 data=dict(type=src.data_type), # pyright: ignore[reportArgumentType] 2022 preprocessing=prep, 2023 ) 2024 2025 2026_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr) 2027 2028 2029class OutputTensorDescr(TensorDescrBase[OutputAxis]): 2030 id: TensorId = TensorId("output") 2031 """Output tensor id. 2032 No duplicates are allowed across all inputs and outputs.""" 2033 2034 postprocessing: List[PostprocessingDescr] = Field( 2035 default_factory=cast(Callable[[], List[PostprocessingDescr]], list) 2036 ) 2037 """Description of how this output should be postprocessed. 2038 2039 note: `postprocessing` always ends with an 'ensure_dtype' operation. 2040 If not given this is added to cast to this tensor's `data.type`. 2041 """ 2042 2043 @model_validator(mode="after") 2044 def _validate_postprocessing_kwargs(self) -> Self: 2045 axes_ids = [a.id for a in self.axes] 2046 for p in self.postprocessing: 2047 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 2048 if kwargs_axes is None: 2049 continue 2050 2051 if not isinstance(kwargs_axes, collections.abc.Sequence): 2052 raise ValueError( 2053 f"expected `axes` sequence, but got {type(kwargs_axes)}" 2054 ) 2055 2056 if any(a not in axes_ids for a in kwargs_axes): 2057 raise ValueError("`kwargs.axes` needs to be subset of axes ids") 2058 2059 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 2060 dtype = self.data.type 2061 else: 2062 dtype = self.data[0].type 2063 2064 # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 2065 if not self.postprocessing or not isinstance( 2066 self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr) 2067 ): 2068 self.postprocessing.append( 2069 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 2070 ) 2071 return self 2072 2073 2074class _OutputTensorConv( 2075 Converter[ 2076 _OutputTensorDescr_v0_4, 2077 OutputTensorDescr, 2078 FileSource_, 2079 Optional[FileSource_], 2080 Mapping[_TensorName_v0_4, Mapping[str, int]], 2081 ] 2082): 2083 def _convert( 2084 self, 2085 src: _OutputTensorDescr_v0_4, 2086 tgt: "type[OutputTensorDescr] | type[dict[str, Any]]", 2087 test_tensor: FileSource_, 2088 sample_tensor: Optional[FileSource_], 2089 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 2090 ) -> "OutputTensorDescr | dict[str, Any]": 2091 # TODO: split convert_axes into convert_output_axes and convert_input_axes 2092 axes: List[OutputAxis] = convert_axes( # pyright: ignore[reportAssignmentType] 2093 src.axes, 2094 shape=src.shape, 2095 tensor_type="output", 2096 halo=src.halo, 2097 size_refs=size_refs, 2098 ) 2099 data_descr: Dict[str, Any] = dict(type=src.data_type) 2100 if data_descr["type"] == "bool": 2101 data_descr["values"] = [False, True] 2102 2103 return tgt( 2104 axes=axes, 2105 id=TensorId(str(src.name)), 2106 test_tensor=FileDescr(source=test_tensor), 2107 
sample_tensor=( 2108 None if sample_tensor is None else FileDescr(source=sample_tensor) 2109 ), 2110 data=data_descr, # pyright: ignore[reportArgumentType] 2111 postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing], 2112 ) 2113 2114 2115_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr) 2116 2117 2118TensorDescr = Union[InputTensorDescr, OutputTensorDescr] 2119 2120 2121def validate_tensors( 2122 tensors: Mapping[TensorId, Tuple[TensorDescr, Optional[NDArray[Any]]]], 2123 tensor_origin: Literal[ 2124 "test_tensor" 2125 ], # for more precise error messages, e.g. 'test_tensor' 2126): 2127 all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, Optional[int]]]] = {} 2128 2129 def e_msg(d: TensorDescr): 2130 return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]" 2131 2132 for descr, array in tensors.values(): 2133 if array is None: 2134 axis_sizes = {a.id: None for a in descr.axes} 2135 else: 2136 try: 2137 axis_sizes = descr.get_axis_sizes_for_array(array) 2138 except ValueError as e: 2139 raise ValueError(f"{e_msg(descr)} {e}") 2140 2141 all_tensor_axes[descr.id] = {a.id: (a, axis_sizes[a.id]) for a in descr.axes} 2142 2143 for descr, array in tensors.values(): 2144 if array is None: 2145 continue 2146 2147 if descr.dtype in ("float32", "float64"): 2148 invalid_test_tensor_dtype = array.dtype.name not in ( 2149 "float32", 2150 "float64", 2151 "uint8", 2152 "int8", 2153 "uint16", 2154 "int16", 2155 "uint32", 2156 "int32", 2157 "uint64", 2158 "int64", 2159 ) 2160 else: 2161 invalid_test_tensor_dtype = array.dtype.name != descr.dtype 2162 2163 if invalid_test_tensor_dtype: 2164 raise ValueError( 2165 f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not" 2166 + f" match described dtype '{descr.dtype}'" 2167 ) 2168 2169 if array.min() > -1e-4 and array.max() < 1e-4: 2170 raise ValueError( 2171 "Output values are too small for reliable testing." 
2172 + f" Values <-1e5 or >=1e5 must be present in {tensor_origin}" 2173 ) 2174 2175 for a in descr.axes: 2176 actual_size = all_tensor_axes[descr.id][a.id][1] 2177 if actual_size is None: 2178 continue 2179 2180 if a.size is None: 2181 continue 2182 2183 if isinstance(a.size, int): 2184 if actual_size != a.size: 2185 raise ValueError( 2186 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' " 2187 + f"has incompatible size {actual_size}, expected {a.size}" 2188 ) 2189 elif isinstance(a.size, ParameterizedSize): 2190 _ = a.size.validate_size(actual_size) 2191 elif isinstance(a.size, DataDependentSize): 2192 _ = a.size.validate_size(actual_size) 2193 elif isinstance(a.size, SizeReference): 2194 ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id) 2195 if ref_tensor_axes is None: 2196 raise ValueError( 2197 f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor" 2198 + f" reference '{a.size.tensor_id}'" 2199 ) 2200 2201 ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None)) 2202 if ref_axis is None or ref_size is None: 2203 raise ValueError( 2204 f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis" 2205 + f" reference '{a.size.tensor_id}.{a.size.axis_id}" 2206 ) 2207 2208 if a.unit != ref_axis.unit: 2209 raise ValueError( 2210 f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires" 2211 + " axis and reference axis to have the same `unit`, but" 2212 + f" {a.unit}!={ref_axis.unit}" 2213 ) 2214 2215 if actual_size != ( 2216 expected_size := ( 2217 ref_size * ref_axis.scale / a.scale + a.size.offset 2218 ) 2219 ): 2220 raise ValueError( 2221 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size" 2222 + f" {actual_size} invalid for referenced size {ref_size};" 2223 + f" expected {expected_size}" 2224 ) 2225 else: 2226 assert_never(a.size) 2227 2228 2229FileDescr_dependencies = Annotated[ 2230 FileDescr_, 2231 WithSuffix((".yaml", ".yml"), case_sensitive=True), 2232 Field(examples=[dict(source="environment.yaml")]), 2233] 2234 2235 2236class _ArchitectureCallableDescr(Node): 2237 callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])] 2238 """Identifier of the callable that returns a torch.nn.Module instance.""" 2239 2240 kwargs: Dict[str, YamlValue] = Field( 2241 default_factory=cast(Callable[[], Dict[str, YamlValue]], dict) 2242 ) 2243 """key word arguments for the `callable`""" 2244 2245 2246class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr): 2247 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2248 """Architecture source file""" 2249 2250 @model_serializer(mode="wrap", when_used="unless-none") 2251 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 2252 return package_file_descr_serializer(self, nxt, info) 2253 2254 2255class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr): 2256 import_from: str 2257 """Where to import the callable from, i.e. 
`from <import_from> import <callable>`""" 2258 2259 2260class _ArchFileConv( 2261 Converter[ 2262 _CallableFromFile_v0_4, 2263 ArchitectureFromFileDescr, 2264 Optional[Sha256], 2265 Dict[str, Any], 2266 ] 2267): 2268 def _convert( 2269 self, 2270 src: _CallableFromFile_v0_4, 2271 tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]", 2272 sha256: Optional[Sha256], 2273 kwargs: Dict[str, Any], 2274 ) -> "ArchitectureFromFileDescr | dict[str, Any]": 2275 if src.startswith("http") and src.count(":") == 2: 2276 http, source, callable_ = src.split(":") 2277 source = ":".join((http, source)) 2278 elif not src.startswith("http") and src.count(":") == 1: 2279 source, callable_ = src.split(":") 2280 else: 2281 source = str(src) 2282 callable_ = str(src) 2283 return tgt( 2284 callable=Identifier(callable_), 2285 source=cast(FileSource_, source), 2286 sha256=sha256, 2287 kwargs=kwargs, 2288 ) 2289 2290 2291_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr) 2292 2293 2294class _ArchLibConv( 2295 Converter[ 2296 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any] 2297 ] 2298): 2299 def _convert( 2300 self, 2301 src: _CallableFromDepencency_v0_4, 2302 tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]", 2303 kwargs: Dict[str, Any], 2304 ) -> "ArchitectureFromLibraryDescr | dict[str, Any]": 2305 *mods, callable_ = src.split(".") 2306 import_from = ".".join(mods) 2307 return tgt( 2308 import_from=import_from, callable=Identifier(callable_), kwargs=kwargs 2309 ) 2310 2311 2312_arch_lib_conv = _ArchLibConv( 2313 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr 2314) 2315 2316 2317class WeightsEntryDescrBase(FileDescr): 2318 type: ClassVar[WeightsFormat] 2319 weights_format_name: ClassVar[str] # human readable 2320 2321 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2322 """Source of the weights file.""" 2323 2324 authors: Optional[List[Author]] = None 2325 """Authors 2326 Either the person(s) that have trained this model resulting in the original weights file. 2327 (If this is the initial weights entry, i.e. it does not have a `parent`) 2328 Or the person(s) who have converted the weights to this weights format. 2329 (If this is a child weight, i.e. it has a `parent` field) 2330 """ 2331 2332 parent: Annotated[ 2333 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"]) 2334 ] = None 2335 """The source weights these weights were converted from. 2336 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`, 2337 The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights. 
2338 All weight entries except one (the initial set of weights resulting from training the model), 2339 need to have this field.""" 2340 2341 comment: str = "" 2342 """A comment about this weights entry, for example how these weights were created.""" 2343 2344 @model_validator(mode="after") 2345 def _validate(self) -> Self: 2346 if self.type == self.parent: 2347 raise ValueError("Weights entry can't be it's own parent.") 2348 2349 return self 2350 2351 @model_serializer(mode="wrap", when_used="unless-none") 2352 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 2353 return package_file_descr_serializer(self, nxt, info) 2354 2355 2356class KerasHdf5WeightsDescr(WeightsEntryDescrBase): 2357 type = "keras_hdf5" 2358 weights_format_name: ClassVar[str] = "Keras HDF5" 2359 tensorflow_version: Version 2360 """TensorFlow version used to create these weights.""" 2361 2362 2363FileDescr_external_data = Annotated[ 2364 FileDescr_, 2365 WithSuffix(".data", case_sensitive=True), 2366 Field(examples=[dict(source="weights.onnx.data")]), 2367] 2368 2369 2370class OnnxWeightsDescr(WeightsEntryDescrBase): 2371 type = "onnx" 2372 weights_format_name: ClassVar[str] = "ONNX" 2373 opset_version: Annotated[int, Ge(7)] 2374 """ONNX opset version""" 2375 2376 external_data: Optional[FileDescr_external_data] = None 2377 """Source of the external ONNX data file holding the weights. 2378 (If present **source** holds the ONNX architecture without weights).""" 2379 2380 @model_validator(mode="after") 2381 def _validate_external_data_unique_file_name(self) -> Self: 2382 if self.external_data is not None and ( 2383 extract_file_name(self.source) 2384 == extract_file_name(self.external_data.source) 2385 ): 2386 raise ValueError( 2387 f"ONNX `external_data` file name '{extract_file_name(self.external_data.source)}'" 2388 + " must be different from ONNX `source` file name." 2389 ) 2390 2391 return self 2392 2393 2394class PytorchStateDictWeightsDescr(WeightsEntryDescrBase): 2395 type = "pytorch_state_dict" 2396 weights_format_name: ClassVar[str] = "Pytorch State Dict" 2397 architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr] 2398 pytorch_version: Version 2399 """Version of the PyTorch library used. 2400 If `architecture.depencencies` is specified it has to include pytorch and any version pinning has to be compatible. 2401 """ 2402 dependencies: Optional[FileDescr_dependencies] = None 2403 """Custom depencies beyond pytorch described in a Conda environment file. 2404 Allows to specify custom dependencies, see conda docs: 2405 - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms) 2406 - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually) 2407 2408 The conda environment file should include pytorch and any version pinning has to be compatible with 2409 **pytorch_version**. 2410 """ 2411 2412 2413class TensorflowJsWeightsDescr(WeightsEntryDescrBase): 2414 type = "tensorflow_js" 2415 weights_format_name: ClassVar[str] = "Tensorflow.js" 2416 tensorflow_version: Version 2417 """Version of the TensorFlow library used.""" 2418 2419 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2420 """The multi-file weights. 
2421 All required files/folders should be a zip archive.""" 2422 2423 2424class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase): 2425 type = "tensorflow_saved_model_bundle" 2426 weights_format_name: ClassVar[str] = "Tensorflow Saved Model" 2427 tensorflow_version: Version 2428 """Version of the TensorFlow library used.""" 2429 2430 dependencies: Optional[FileDescr_dependencies] = None 2431 """Custom dependencies beyond tensorflow. 2432 Should include tensorflow and any version pinning has to be compatible with **tensorflow_version**.""" 2433 2434 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2435 """The multi-file weights. 2436 All required files/folders should be a zip archive.""" 2437 2438 2439class TorchscriptWeightsDescr(WeightsEntryDescrBase): 2440 type = "torchscript" 2441 weights_format_name: ClassVar[str] = "TorchScript" 2442 pytorch_version: Version 2443 """Version of the PyTorch library used.""" 2444 2445 2446class WeightsDescr(Node): 2447 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None 2448 onnx: Optional[OnnxWeightsDescr] = None 2449 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None 2450 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None 2451 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = ( 2452 None 2453 ) 2454 torchscript: Optional[TorchscriptWeightsDescr] = None 2455 2456 @model_validator(mode="after") 2457 def check_entries(self) -> Self: 2458 entries = {wtype for wtype, entry in self if entry is not None} 2459 2460 if not entries: 2461 raise ValueError("Missing weights entry") 2462 2463 entries_wo_parent = { 2464 wtype 2465 for wtype, entry in self 2466 if entry is not None and hasattr(entry, "parent") and entry.parent is None 2467 } 2468 if len(entries_wo_parent) != 1: 2469 issue_warning( 2470 "Exactly one weights entry may not specify the `parent` field (got" 2471 + " {value}). That entry is considered the original set of model weights." 2472 + " Other weight formats are created through conversion of the orignal or" 2473 + " already converted weights. 
They have to reference the weights format" 2474 + " they were converted from as their `parent`.", 2475 value=len(entries_wo_parent), 2476 field="weights", 2477 ) 2478 2479 for wtype, entry in self: 2480 if entry is None: 2481 continue 2482 2483 assert hasattr(entry, "type") 2484 assert hasattr(entry, "parent") 2485 assert wtype == entry.type 2486 if ( 2487 entry.parent is not None and entry.parent not in entries 2488 ): # self reference checked for `parent` field 2489 raise ValueError( 2490 f"`weights.{wtype}.parent={entry.parent} not in specified weight" 2491 + f" formats: {entries}" 2492 ) 2493 2494 return self 2495 2496 def __getitem__( 2497 self, 2498 key: Literal[ 2499 "keras_hdf5", 2500 "onnx", 2501 "pytorch_state_dict", 2502 "tensorflow_js", 2503 "tensorflow_saved_model_bundle", 2504 "torchscript", 2505 ], 2506 ): 2507 if key == "keras_hdf5": 2508 ret = self.keras_hdf5 2509 elif key == "onnx": 2510 ret = self.onnx 2511 elif key == "pytorch_state_dict": 2512 ret = self.pytorch_state_dict 2513 elif key == "tensorflow_js": 2514 ret = self.tensorflow_js 2515 elif key == "tensorflow_saved_model_bundle": 2516 ret = self.tensorflow_saved_model_bundle 2517 elif key == "torchscript": 2518 ret = self.torchscript 2519 else: 2520 raise KeyError(key) 2521 2522 if ret is None: 2523 raise KeyError(key) 2524 2525 return ret 2526 2527 @property 2528 def available_formats(self): 2529 return { 2530 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 2531 **({} if self.onnx is None else {"onnx": self.onnx}), 2532 **( 2533 {} 2534 if self.pytorch_state_dict is None 2535 else {"pytorch_state_dict": self.pytorch_state_dict} 2536 ), 2537 **( 2538 {} 2539 if self.tensorflow_js is None 2540 else {"tensorflow_js": self.tensorflow_js} 2541 ), 2542 **( 2543 {} 2544 if self.tensorflow_saved_model_bundle is None 2545 else { 2546 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 2547 } 2548 ), 2549 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 2550 } 2551 2552 @property 2553 def missing_formats(self): 2554 return { 2555 wf for wf in get_args(WeightsFormat) if wf not in self.available_formats 2556 } 2557 2558 2559class ModelId(ResourceId): 2560 pass 2561 2562 2563class LinkedModel(LinkedResourceBase): 2564 """Reference to a bioimage.io model.""" 2565 2566 id: ModelId 2567 """A valid model `id` from the bioimage.io collection.""" 2568 2569 2570class _DataDepSize(NamedTuple): 2571 min: StrictInt 2572 max: Optional[StrictInt] 2573 2574 2575class _AxisSizes(NamedTuple): 2576 """the lenghts of all axes of model inputs and outputs""" 2577 2578 inputs: Dict[Tuple[TensorId, AxisId], int] 2579 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] 2580 2581 2582class _TensorSizes(NamedTuple): 2583 """_AxisSizes as nested dicts""" 2584 2585 inputs: Dict[TensorId, Dict[AxisId, int]] 2586 outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]] 2587 2588 2589class ReproducibilityTolerance(Node, extra="allow"): 2590 """Describes what small numerical differences -- if any -- may be tolerated 2591 in the generated output when executing in different environments. 2592 2593 A tensor element *output* is considered mismatched to the **test_tensor** if 2594 abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**). 2595 (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).) 
2596 2597 Motivation: 2598 For testing we can request the respective deep learning frameworks to be as 2599 reproducible as possible by setting seeds and chosing deterministic algorithms, 2600 but differences in operating systems, available hardware and installed drivers 2601 may still lead to numerical differences. 2602 """ 2603 2604 relative_tolerance: RelativeTolerance = 1e-3 2605 """Maximum relative tolerance of reproduced test tensor.""" 2606 2607 absolute_tolerance: AbsoluteTolerance = 1e-4 2608 """Maximum absolute tolerance of reproduced test tensor.""" 2609 2610 mismatched_elements_per_million: MismatchedElementsPerMillion = 100 2611 """Maximum number of mismatched elements/pixels per million to tolerate.""" 2612 2613 output_ids: Sequence[TensorId] = () 2614 """Limits the output tensor IDs these reproducibility details apply to.""" 2615 2616 weights_formats: Sequence[WeightsFormat] = () 2617 """Limits the weights formats these details apply to.""" 2618 2619 2620class BioimageioConfig(Node, extra="allow"): 2621 reproducibility_tolerance: Sequence[ReproducibilityTolerance] = () 2622 """Tolerances to allow when reproducing the model's test outputs 2623 from the model's test inputs. 2624 Only the first entry matching tensor id and weights format is considered. 2625 """ 2626 2627 2628class Config(Node, extra="allow"): 2629 bioimageio: BioimageioConfig = Field( 2630 default_factory=BioimageioConfig.model_construct 2631 ) 2632 2633 2634class ModelDescr(GenericModelDescrBase): 2635 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. 2636 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 2637 """ 2638 2639 implemented_format_version: ClassVar[Literal["0.5.6"]] = "0.5.6" 2640 if TYPE_CHECKING: 2641 format_version: Literal["0.5.6"] = "0.5.6" 2642 else: 2643 format_version: Literal["0.5.6"] 2644 """Version of the bioimage.io model description specification used. 2645 When creating a new model always use the latest micro/patch version described here. 2646 The `format_version` is important for any consumer software to understand how to parse the fields. 2647 """ 2648 2649 implemented_type: ClassVar[Literal["model"]] = "model" 2650 if TYPE_CHECKING: 2651 type: Literal["model"] = "model" 2652 else: 2653 type: Literal["model"] 2654 """Specialized resource type 'model'""" 2655 2656 id: Optional[ModelId] = None 2657 """bioimage.io-wide unique resource identifier 2658 assigned by bioimage.io; version **un**specific.""" 2659 2660 authors: FAIR[List[Author]] = Field( 2661 default_factory=cast(Callable[[], List[Author]], list) 2662 ) 2663 """The authors are the creators of the model RDF and the primary points of contact.""" 2664 2665 documentation: FAIR[Optional[FileSource_documentation]] = None 2666 """URL or relative path to a markdown file with additional documentation. 2667 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. 
2668 The documentation should include a '#[#] Validation' (sub)section 2669 with details on how to quantitatively validate the model on unseen data.""" 2670 2671 @field_validator("documentation", mode="after") 2672 @classmethod 2673 def _validate_documentation( 2674 cls, value: Optional[FileSource_documentation] 2675 ) -> Optional[FileSource_documentation]: 2676 if not get_validation_context().perform_io_checks or value is None: 2677 return value 2678 2679 doc_reader = get_reader(value) 2680 doc_content = doc_reader.read().decode(encoding="utf-8") 2681 if not re.search("#.*[vV]alidation", doc_content): 2682 issue_warning( 2683 "No '# Validation' (sub)section found in {value}.", 2684 value=value, 2685 field="documentation", 2686 ) 2687 2688 return value 2689 2690 inputs: NotEmpty[Sequence[InputTensorDescr]] 2691 """Describes the input tensors expected by this model.""" 2692 2693 @field_validator("inputs", mode="after") 2694 @classmethod 2695 def _validate_input_axes( 2696 cls, inputs: Sequence[InputTensorDescr] 2697 ) -> Sequence[InputTensorDescr]: 2698 input_size_refs = cls._get_axes_with_independent_size(inputs) 2699 2700 for i, ipt in enumerate(inputs): 2701 valid_independent_refs: Dict[ 2702 Tuple[TensorId, AxisId], 2703 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2704 ] = { 2705 **{ 2706 (ipt.id, a.id): (ipt, a, a.size) 2707 for a in ipt.axes 2708 if not isinstance(a, BatchAxis) 2709 and isinstance(a.size, (int, ParameterizedSize)) 2710 }, 2711 **input_size_refs, 2712 } 2713 for a, ax in enumerate(ipt.axes): 2714 cls._validate_axis( 2715 "inputs", 2716 i=i, 2717 tensor_id=ipt.id, 2718 a=a, 2719 axis=ax, 2720 valid_independent_refs=valid_independent_refs, 2721 ) 2722 return inputs 2723 2724 @staticmethod 2725 def _validate_axis( 2726 field_name: str, 2727 i: int, 2728 tensor_id: TensorId, 2729 a: int, 2730 axis: AnyAxis, 2731 valid_independent_refs: Dict[ 2732 Tuple[TensorId, AxisId], 2733 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2734 ], 2735 ): 2736 if isinstance(axis, BatchAxis) or isinstance( 2737 axis.size, (int, ParameterizedSize, DataDependentSize) 2738 ): 2739 return 2740 elif not isinstance(axis.size, SizeReference): 2741 assert_never(axis.size) 2742 2743 # validate axis.size SizeReference 2744 ref = (axis.size.tensor_id, axis.size.axis_id) 2745 if ref not in valid_independent_refs: 2746 raise ValueError( 2747 "Invalid tensor axis reference at" 2748 + f" {field_name}[{i}].axes[{a}].size: {axis.size}." 2749 ) 2750 if ref == (tensor_id, axis.id): 2751 raise ValueError( 2752 "Self-referencing not allowed for" 2753 + f" {field_name}[{i}].axes[{a}].size: {axis.size}" 2754 ) 2755 if axis.type == "channel": 2756 if valid_independent_refs[ref][1].type != "channel": 2757 raise ValueError( 2758 "A channel axis' size may only reference another fixed size" 2759 + " channel axis." 
2760 ) 2761 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names: 2762 ref_size = valid_independent_refs[ref][2] 2763 assert isinstance(ref_size, int), ( 2764 "channel axis ref (another channel axis) has to specify fixed" 2765 + " size" 2766 ) 2767 generated_channel_names = [ 2768 Identifier(axis.channel_names.format(i=i)) 2769 for i in range(1, ref_size + 1) 2770 ] 2771 axis.channel_names = generated_channel_names 2772 2773 if (ax_unit := getattr(axis, "unit", None)) != ( 2774 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None) 2775 ): 2776 raise ValueError( 2777 "The units of an axis and its reference axis need to match, but" 2778 + f" '{ax_unit}' != '{ref_unit}'." 2779 ) 2780 ref_axis = valid_independent_refs[ref][1] 2781 if isinstance(ref_axis, BatchAxis): 2782 raise ValueError( 2783 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}" 2784 + " (a batch axis is not allowed as reference)." 2785 ) 2786 2787 if isinstance(axis, WithHalo): 2788 min_size = axis.size.get_size(axis, ref_axis, n=0) 2789 if (min_size - 2 * axis.halo) < 1: 2790 raise ValueError( 2791 f"axis {axis.id} with minimum size {min_size} is too small for halo" 2792 + f" {axis.halo}." 2793 ) 2794 2795 input_halo = axis.halo * axis.scale / ref_axis.scale 2796 if input_halo != int(input_halo) or input_halo % 2 == 1: 2797 raise ValueError( 2798 f"input_halo {input_halo} (output_halo {axis.halo} *" 2799 + f" output_scale {axis.scale} / input_scale {ref_axis.scale})" 2800 + f" {tensor_id}.{axis.id}." 2801 ) 2802 2803 @model_validator(mode="after") 2804 def _validate_test_tensors(self) -> Self: 2805 if not get_validation_context().perform_io_checks: 2806 return self 2807 2808 test_output_arrays = [ 2809 None if descr.test_tensor is None else load_array(descr.test_tensor) 2810 for descr in self.outputs 2811 ] 2812 test_input_arrays = [ 2813 None if descr.test_tensor is None else load_array(descr.test_tensor) 2814 for descr in self.inputs 2815 ] 2816 2817 tensors = { 2818 descr.id: (descr, array) 2819 for descr, array in zip( 2820 chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays 2821 ) 2822 } 2823 validate_tensors(tensors, tensor_origin="test_tensor") 2824 2825 output_arrays = { 2826 descr.id: array for descr, array in zip(self.outputs, test_output_arrays) 2827 } 2828 for rep_tol in self.config.bioimageio.reproducibility_tolerance: 2829 if not rep_tol.absolute_tolerance: 2830 continue 2831 2832 if rep_tol.output_ids: 2833 out_arrays = { 2834 oid: a 2835 for oid, a in output_arrays.items() 2836 if oid in rep_tol.output_ids 2837 } 2838 else: 2839 out_arrays = output_arrays 2840 2841 for out_id, array in out_arrays.items(): 2842 if array is None: 2843 continue 2844 2845 if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01: 2846 raise ValueError( 2847 "config.bioimageio.reproducibility_tolerance.absolute_tolerance=" 2848 + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}" 2849 + f" (1% of the maximum value of the test tensor '{out_id}')" 2850 ) 2851 2852 return self 2853 2854 @model_validator(mode="after") 2855 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self: 2856 ipt_refs = {t.id for t in self.inputs} 2857 out_refs = {t.id for t in self.outputs} 2858 for ipt in self.inputs: 2859 for p in ipt.preprocessing: 2860 ref = p.kwargs.get("reference_tensor") 2861 if ref is None: 2862 continue 2863 if ref not in ipt_refs: 2864 raise ValueError( 2865 f"`reference_tensor` '{ref}' not found. 
Valid input tensor" 2866 + f" references are: {ipt_refs}." 2867 ) 2868 2869 for out in self.outputs: 2870 for p in out.postprocessing: 2871 ref = p.kwargs.get("reference_tensor") 2872 if ref is None: 2873 continue 2874 2875 if ref not in ipt_refs and ref not in out_refs: 2876 raise ValueError( 2877 f"`reference_tensor` '{ref}' not found. Valid tensor references" 2878 + f" are: {ipt_refs | out_refs}." 2879 ) 2880 2881 return self 2882 2883 # TODO: use validate funcs in validate_test_tensors 2884 # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]: 2885 2886 name: Annotated[ 2887 str, 2888 RestrictCharacters(string.ascii_letters + string.digits + "_+- ()"), 2889 MinLen(5), 2890 MaxLen(128), 2891 warn(MaxLen(64), "Name longer than 64 characters.", INFO), 2892 ] 2893 """A human-readable name of this model. 2894 It should be no longer than 64 characters 2895 and may only contain letter, number, underscore, minus, parentheses and spaces. 2896 We recommend to chose a name that refers to the model's task and image modality. 2897 """ 2898 2899 outputs: NotEmpty[Sequence[OutputTensorDescr]] 2900 """Describes the output tensors.""" 2901 2902 @field_validator("outputs", mode="after") 2903 @classmethod 2904 def _validate_tensor_ids( 2905 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo 2906 ) -> Sequence[OutputTensorDescr]: 2907 tensor_ids = [ 2908 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", []) 2909 ] 2910 duplicate_tensor_ids: List[str] = [] 2911 seen: Set[str] = set() 2912 for t in tensor_ids: 2913 if t in seen: 2914 duplicate_tensor_ids.append(t) 2915 2916 seen.add(t) 2917 2918 if duplicate_tensor_ids: 2919 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}") 2920 2921 return outputs 2922 2923 @staticmethod 2924 def _get_axes_with_parameterized_size( 2925 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2926 ): 2927 return { 2928 f"{t.id}.{a.id}": (t, a, a.size) 2929 for t in io 2930 for a in t.axes 2931 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize) 2932 } 2933 2934 @staticmethod 2935 def _get_axes_with_independent_size( 2936 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2937 ): 2938 return { 2939 (t.id, a.id): (t, a, a.size) 2940 for t in io 2941 for a in t.axes 2942 if not isinstance(a, BatchAxis) 2943 and isinstance(a.size, (int, ParameterizedSize)) 2944 } 2945 2946 @field_validator("outputs", mode="after") 2947 @classmethod 2948 def _validate_output_axes( 2949 cls, outputs: List[OutputTensorDescr], info: ValidationInfo 2950 ) -> List[OutputTensorDescr]: 2951 input_size_refs = cls._get_axes_with_independent_size( 2952 info.data.get("inputs", []) 2953 ) 2954 output_size_refs = cls._get_axes_with_independent_size(outputs) 2955 2956 for i, out in enumerate(outputs): 2957 valid_independent_refs: Dict[ 2958 Tuple[TensorId, AxisId], 2959 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2960 ] = { 2961 **{ 2962 (out.id, a.id): (out, a, a.size) 2963 for a in out.axes 2964 if not isinstance(a, BatchAxis) 2965 and isinstance(a.size, (int, ParameterizedSize)) 2966 }, 2967 **input_size_refs, 2968 **output_size_refs, 2969 } 2970 for a, ax in enumerate(out.axes): 2971 cls._validate_axis( 2972 "outputs", 2973 i, 2974 out.id, 2975 a, 2976 ax, 2977 valid_independent_refs=valid_independent_refs, 2978 ) 2979 2980 return outputs 2981 2982 packaged_by: List[Author] = Field( 2983 default_factory=cast(Callable[[], List[Author]], 
list) 2984 ) 2985 """The persons that have packaged and uploaded this model. 2986 Only required if those persons differ from the `authors`.""" 2987 2988 parent: Optional[LinkedModel] = None 2989 """The model from which this model is derived, e.g. by fine-tuning the weights.""" 2990 2991 @model_validator(mode="after") 2992 def _validate_parent_is_not_self(self) -> Self: 2993 if self.parent is not None and self.parent.id == self.id: 2994 raise ValueError("A model description may not reference itself as parent.") 2995 2996 return self 2997 2998 run_mode: Annotated[ 2999 Optional[RunMode], 3000 warn(None, "Run mode '{value}' has limited support across consumer softwares."), 3001 ] = None 3002 """Custom run mode for this model: for more complex prediction procedures like test time 3003 data augmentation that currently cannot be expressed in the specification. 3004 No standard run modes are defined yet.""" 3005 3006 timestamp: Datetime = Field(default_factory=Datetime.now) 3007 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format 3008 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat). 3009 (In Python a datetime object is valid, too).""" 3010 3011 training_data: Annotated[ 3012 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02], 3013 Field(union_mode="left_to_right"), 3014 ] = None 3015 """The dataset used to train this model""" 3016 3017 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)] 3018 """The weights for this model. 3019 Weights can be given for different formats, but should otherwise be equivalent. 3020 The available weight formats determine which consumers can use this model.""" 3021 3022 config: Config = Field(default_factory=Config.model_construct) 3023 3024 @model_validator(mode="after") 3025 def _add_default_cover(self) -> Self: 3026 if not get_validation_context().perform_io_checks or self.covers: 3027 return self 3028 3029 try: 3030 generated_covers = generate_covers( 3031 [ 3032 (t, load_array(t.test_tensor)) 3033 for t in self.inputs 3034 if t.test_tensor is not None 3035 ], 3036 [ 3037 (t, load_array(t.test_tensor)) 3038 for t in self.outputs 3039 if t.test_tensor is not None 3040 ], 3041 ) 3042 except Exception as e: 3043 issue_warning( 3044 "Failed to generate cover image(s): {e}", 3045 value=self.covers, 3046 msg_context=dict(e=e), 3047 field="covers", 3048 ) 3049 else: 3050 self.covers.extend(generated_covers) 3051 3052 return self 3053 3054 def get_input_test_arrays(self) -> List[NDArray[Any]]: 3055 return self._get_test_arrays(self.inputs) 3056 3057 def get_output_test_arrays(self) -> List[NDArray[Any]]: 3058 return self._get_test_arrays(self.outputs) 3059 3060 @staticmethod 3061 def _get_test_arrays( 3062 io_descr: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 3063 ): 3064 ts: List[FileDescr] = [] 3065 for d in io_descr: 3066 if d.test_tensor is None: 3067 raise ValueError( 3068 f"Failed to get test arrays: description of '{d.id}' is missing a `test_tensor`." 
3069 ) 3070 ts.append(d.test_tensor) 3071 3072 data = [load_array(t) for t in ts] 3073 assert all(isinstance(d, np.ndarray) for d in data) 3074 return data 3075 3076 @staticmethod 3077 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 3078 batch_size = 1 3079 tensor_with_batchsize: Optional[TensorId] = None 3080 for tid in tensor_sizes: 3081 for aid, s in tensor_sizes[tid].items(): 3082 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 3083 continue 3084 3085 if batch_size != 1: 3086 assert tensor_with_batchsize is not None 3087 raise ValueError( 3088 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 3089 ) 3090 3091 batch_size = s 3092 tensor_with_batchsize = tid 3093 3094 return batch_size 3095 3096 def get_output_tensor_sizes( 3097 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 3098 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: 3099 """Returns the tensor output sizes for given **input_sizes**. 3100 Only if **input_sizes** has a valid input shape, the tensor output size is exact. 3101 Otherwise it might be larger than the actual (valid) output""" 3102 batch_size = self.get_batch_size(input_sizes) 3103 ns = self.get_ns(input_sizes) 3104 3105 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 3106 return tensor_sizes.outputs 3107 3108 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 3109 """get parameter `n` for each parameterized axis 3110 such that the valid input size is >= the given input size""" 3111 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 3112 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 3113 for tid in input_sizes: 3114 for aid, s in input_sizes[tid].items(): 3115 size_descr = axes[tid][aid].size 3116 if isinstance(size_descr, ParameterizedSize): 3117 ret[(tid, aid)] = size_descr.get_n(s) 3118 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 3119 pass 3120 else: 3121 assert_never(size_descr) 3122 3123 return ret 3124 3125 def get_tensor_sizes( 3126 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 3127 ) -> _TensorSizes: 3128 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 3129 return _TensorSizes( 3130 { 3131 t: { 3132 aa: axis_sizes.inputs[(tt, aa)] 3133 for tt, aa in axis_sizes.inputs 3134 if tt == t 3135 } 3136 for t in {tt for tt, _ in axis_sizes.inputs} 3137 }, 3138 { 3139 t: { 3140 aa: axis_sizes.outputs[(tt, aa)] 3141 for tt, aa in axis_sizes.outputs 3142 if tt == t 3143 } 3144 for t in {tt for tt, _ in axis_sizes.outputs} 3145 }, 3146 ) 3147 3148 def get_axis_sizes( 3149 self, 3150 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], 3151 batch_size: Optional[int] = None, 3152 *, 3153 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None, 3154 ) -> _AxisSizes: 3155 """Determine input and output block shape for scale factors **ns** 3156 of parameterized input sizes. 3157 3158 Args: 3159 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) 3160 that is parameterized as `size = min + n * step`. 3161 batch_size: The desired size of the batch dimension. 3162 If given **batch_size** overwrites any batch size present in 3163 **max_input_shape**. Default 1. 3164 max_input_shape: Limits the derived block shapes. 3165 Each axis for which the input size, parameterized by `n`, is larger 3166 than **max_input_shape** is set to the minimal value `n_min` for which 3167 this is still true. 
3168 Use this for small input samples or large values of **ns**. 3169 Or simply whenever you know the full input shape. 3170 3171 Returns: 3172 Resolved axis sizes for model inputs and outputs. 3173 """ 3174 max_input_shape = max_input_shape or {} 3175 if batch_size is None: 3176 for (_t_id, a_id), s in max_input_shape.items(): 3177 if a_id == BATCH_AXIS_ID: 3178 batch_size = s 3179 break 3180 else: 3181 batch_size = 1 3182 3183 all_axes = { 3184 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs) 3185 } 3186 3187 inputs: Dict[Tuple[TensorId, AxisId], int] = {} 3188 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {} 3189 3190 def get_axis_size(a: Union[InputAxis, OutputAxis]): 3191 if isinstance(a, BatchAxis): 3192 if (t_descr.id, a.id) in ns: 3193 logger.warning( 3194 "Ignoring unexpected size increment factor (n) for batch axis" 3195 + " of tensor '{}'.", 3196 t_descr.id, 3197 ) 3198 return batch_size 3199 elif isinstance(a.size, int): 3200 if (t_descr.id, a.id) in ns: 3201 logger.warning( 3202 "Ignoring unexpected size increment factor (n) for fixed size" 3203 + " axis '{}' of tensor '{}'.", 3204 a.id, 3205 t_descr.id, 3206 ) 3207 return a.size 3208 elif isinstance(a.size, ParameterizedSize): 3209 if (t_descr.id, a.id) not in ns: 3210 raise ValueError( 3211 "Size increment factor (n) missing for parametrized axis" 3212 + f" '{a.id}' of tensor '{t_descr.id}'." 3213 ) 3214 n = ns[(t_descr.id, a.id)] 3215 s_max = max_input_shape.get((t_descr.id, a.id)) 3216 if s_max is not None: 3217 n = min(n, a.size.get_n(s_max)) 3218 3219 return a.size.get_size(n) 3220 3221 elif isinstance(a.size, SizeReference): 3222 if (t_descr.id, a.id) in ns: 3223 logger.warning( 3224 "Ignoring unexpected size increment factor (n) for axis '{}'" 3225 + " of tensor '{}' with size reference.", 3226 a.id, 3227 t_descr.id, 3228 ) 3229 assert not isinstance(a, BatchAxis) 3230 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id] 3231 assert not isinstance(ref_axis, BatchAxis) 3232 ref_key = (a.size.tensor_id, a.size.axis_id) 3233 ref_size = inputs.get(ref_key, outputs.get(ref_key)) 3234 assert ref_size is not None, ref_key 3235 assert not isinstance(ref_size, _DataDepSize), ref_key 3236 return a.size.get_size( 3237 axis=a, 3238 ref_axis=ref_axis, 3239 ref_size=ref_size, 3240 ) 3241 elif isinstance(a.size, DataDependentSize): 3242 if (t_descr.id, a.id) in ns: 3243 logger.warning( 3244 "Ignoring unexpected increment factor (n) for data dependent" 3245 + " size axis '{}' of tensor '{}'.", 3246 a.id, 3247 t_descr.id, 3248 ) 3249 return _DataDepSize(a.size.min, a.size.max) 3250 else: 3251 assert_never(a.size) 3252 3253 # first resolve all , but the `SizeReference` input sizes 3254 for t_descr in self.inputs: 3255 for a in t_descr.axes: 3256 if not isinstance(a.size, SizeReference): 3257 s = get_axis_size(a) 3258 assert not isinstance(s, _DataDepSize) 3259 inputs[t_descr.id, a.id] = s 3260 3261 # resolve all other input axis sizes 3262 for t_descr in self.inputs: 3263 for a in t_descr.axes: 3264 if isinstance(a.size, SizeReference): 3265 s = get_axis_size(a) 3266 assert not isinstance(s, _DataDepSize) 3267 inputs[t_descr.id, a.id] = s 3268 3269 # resolve all output axis sizes 3270 for t_descr in self.outputs: 3271 for a in t_descr.axes: 3272 assert not isinstance(a.size, ParameterizedSize) 3273 s = get_axis_size(a) 3274 outputs[t_descr.id, a.id] = s 3275 3276 return _AxisSizes(inputs=inputs, outputs=outputs) 3277 3278 @model_validator(mode="before") 3279 @classmethod 3280 def 
_convert(cls, data: Dict[str, Any]) -> Dict[str, Any]: 3281 cls.convert_from_old_format_wo_validation(data) 3282 return data 3283 3284 @classmethod 3285 def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None: 3286 """Convert metadata following an older format version to this classes' format 3287 without validating the result. 3288 """ 3289 if ( 3290 data.get("type") == "model" 3291 and isinstance(fv := data.get("format_version"), str) 3292 and fv.count(".") == 2 3293 ): 3294 fv_parts = fv.split(".") 3295 if any(not p.isdigit() for p in fv_parts): 3296 return 3297 3298 fv_tuple = tuple(map(int, fv_parts)) 3299 3300 assert cls.implemented_format_version_tuple[0:2] == (0, 5) 3301 if fv_tuple[:2] in ((0, 3), (0, 4)): 3302 m04 = _ModelDescr_v0_4.load(data) 3303 if isinstance(m04, InvalidDescr): 3304 try: 3305 updated = _model_conv.convert_as_dict( 3306 m04 # pyright: ignore[reportArgumentType] 3307 ) 3308 except Exception as e: 3309 logger.error( 3310 "Failed to convert from invalid model 0.4 description." 3311 + f"\nerror: {e}" 3312 + "\nProceeding with model 0.5 validation without conversion." 3313 ) 3314 updated = None 3315 else: 3316 updated = _model_conv.convert_as_dict(m04) 3317 3318 if updated is not None: 3319 data.clear() 3320 data.update(updated) 3321 3322 elif fv_tuple[:2] == (0, 5): 3323 # bump patch version 3324 data["format_version"] = cls.implemented_format_version 3325 3326 3327class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]): 3328 def _convert( 3329 self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]" 3330 ) -> "ModelDescr | dict[str, Any]": 3331 name = "".join( 3332 c if c in string.ascii_letters + string.digits + "_+- ()" else " " 3333 for c in src.name 3334 ) 3335 3336 def conv_authors(auths: Optional[Sequence[_Author_v0_4]]): 3337 conv = ( 3338 _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict 3339 ) 3340 return None if auths is None else [conv(a) for a in auths] 3341 3342 if TYPE_CHECKING: 3343 arch_file_conv = _arch_file_conv.convert 3344 arch_lib_conv = _arch_lib_conv.convert 3345 else: 3346 arch_file_conv = _arch_file_conv.convert_as_dict 3347 arch_lib_conv = _arch_lib_conv.convert_as_dict 3348 3349 input_size_refs = { 3350 ipt.name: { 3351 a: s 3352 for a, s in zip( 3353 ipt.axes, 3354 ( 3355 ipt.shape.min 3356 if isinstance(ipt.shape, _ParameterizedInputShape_v0_4) 3357 else ipt.shape 3358 ), 3359 ) 3360 } 3361 for ipt in src.inputs 3362 if ipt.shape 3363 } 3364 output_size_refs = { 3365 **{ 3366 out.name: {a: s for a, s in zip(out.axes, out.shape)} 3367 for out in src.outputs 3368 if not isinstance(out.shape, _ImplicitOutputShape_v0_4) 3369 }, 3370 **input_size_refs, 3371 } 3372 3373 return tgt( 3374 attachments=( 3375 [] 3376 if src.attachments is None 3377 else [FileDescr(source=f) for f in src.attachments.files] 3378 ), 3379 authors=[_author_conv.convert_as_dict(a) for a in src.authors], # pyright: ignore[reportArgumentType] 3380 cite=[{"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite], # pyright: ignore[reportArgumentType] 3381 config=src.config, # pyright: ignore[reportArgumentType] 3382 covers=src.covers, 3383 description=src.description, 3384 documentation=src.documentation, 3385 format_version="0.5.6", 3386 git_repo=src.git_repo, # pyright: ignore[reportArgumentType] 3387 icon=src.icon, 3388 id=None if src.id is None else ModelId(src.id), 3389 id_emoji=src.id_emoji, 3390 license=src.license, # type: ignore 3391 links=src.links, 3392 
maintainers=[_maintainer_conv.convert_as_dict(m) for m in src.maintainers], # pyright: ignore[reportArgumentType] 3393 name=name, 3394 tags=src.tags, 3395 type=src.type, 3396 uploader=src.uploader, 3397 version=src.version, 3398 inputs=[ # pyright: ignore[reportArgumentType] 3399 _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs) 3400 for ipt, tt, st in zip( 3401 src.inputs, 3402 src.test_inputs, 3403 src.sample_inputs or [None] * len(src.test_inputs), 3404 ) 3405 ], 3406 outputs=[ # pyright: ignore[reportArgumentType] 3407 _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs) 3408 for out, tt, st in zip( 3409 src.outputs, 3410 src.test_outputs, 3411 src.sample_outputs or [None] * len(src.test_outputs), 3412 ) 3413 ], 3414 parent=( 3415 None 3416 if src.parent is None 3417 else LinkedModel( 3418 id=ModelId( 3419 str(src.parent.id) 3420 + ( 3421 "" 3422 if src.parent.version_number is None 3423 else f"/{src.parent.version_number}" 3424 ) 3425 ) 3426 ) 3427 ), 3428 training_data=( 3429 None 3430 if src.training_data is None 3431 else ( 3432 LinkedDataset( 3433 id=DatasetId( 3434 str(src.training_data.id) 3435 + ( 3436 "" 3437 if src.training_data.version_number is None 3438 else f"/{src.training_data.version_number}" 3439 ) 3440 ) 3441 ) 3442 if isinstance(src.training_data, LinkedDataset02) 3443 else src.training_data 3444 ) 3445 ), 3446 packaged_by=[_author_conv.convert_as_dict(a) for a in src.packaged_by], # pyright: ignore[reportArgumentType] 3447 run_mode=src.run_mode, 3448 timestamp=src.timestamp, 3449 weights=(WeightsDescr if TYPE_CHECKING else dict)( 3450 keras_hdf5=(w := src.weights.keras_hdf5) 3451 and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)( 3452 authors=conv_authors(w.authors), 3453 source=w.source, 3454 tensorflow_version=w.tensorflow_version or Version("1.15"), 3455 parent=w.parent, 3456 ), 3457 onnx=(w := src.weights.onnx) 3458 and (OnnxWeightsDescr if TYPE_CHECKING else dict)( 3459 source=w.source, 3460 authors=conv_authors(w.authors), 3461 parent=w.parent, 3462 opset_version=w.opset_version or 15, 3463 ), 3464 pytorch_state_dict=(w := src.weights.pytorch_state_dict) 3465 and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)( 3466 source=w.source, 3467 authors=conv_authors(w.authors), 3468 parent=w.parent, 3469 architecture=( 3470 arch_file_conv( 3471 w.architecture, 3472 w.architecture_sha256, 3473 w.kwargs, 3474 ) 3475 if isinstance(w.architecture, _CallableFromFile_v0_4) 3476 else arch_lib_conv(w.architecture, w.kwargs) 3477 ), 3478 pytorch_version=w.pytorch_version or Version("1.10"), 3479 dependencies=( 3480 None 3481 if w.dependencies is None 3482 else (FileDescr if TYPE_CHECKING else dict)( 3483 source=cast( 3484 FileSource, 3485 str(deps := w.dependencies)[ 3486 ( 3487 len("conda:") 3488 if str(deps).startswith("conda:") 3489 else 0 3490 ) : 3491 ], 3492 ) 3493 ) 3494 ), 3495 ), 3496 tensorflow_js=(w := src.weights.tensorflow_js) 3497 and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)( 3498 source=w.source, 3499 authors=conv_authors(w.authors), 3500 parent=w.parent, 3501 tensorflow_version=w.tensorflow_version or Version("1.15"), 3502 ), 3503 tensorflow_saved_model_bundle=( 3504 w := src.weights.tensorflow_saved_model_bundle 3505 ) 3506 and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)( 3507 authors=conv_authors(w.authors), 3508 parent=w.parent, 3509 source=w.source, 3510 tensorflow_version=w.tensorflow_version or Version("1.15"), 3511 dependencies=( 3512 None 3513 if w.dependencies is None 
3514 else (FileDescr if TYPE_CHECKING else dict)( 3515 source=cast( 3516 FileSource, 3517 ( 3518 str(w.dependencies)[len("conda:") :] 3519 if str(w.dependencies).startswith("conda:") 3520 else str(w.dependencies) 3521 ), 3522 ) 3523 ) 3524 ), 3525 ), 3526 torchscript=(w := src.weights.torchscript) 3527 and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)( 3528 source=w.source, 3529 authors=conv_authors(w.authors), 3530 parent=w.parent, 3531 pytorch_version=w.pytorch_version or Version("1.10"), 3532 ), 3533 ), 3534 ) 3535 3536 3537_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr) 3538 3539 3540# create better cover images for 3d data and non-image outputs 3541def generate_covers( 3542 inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]], 3543 outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]], 3544) -> List[Path]: 3545 def squeeze( 3546 data: NDArray[Any], axes: Sequence[AnyAxis] 3547 ) -> Tuple[NDArray[Any], List[AnyAxis]]: 3548 """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining""" 3549 if data.ndim != len(axes): 3550 raise ValueError( 3551 f"tensor shape {data.shape} does not match described axes" 3552 + f" {[a.id for a in axes]}" 3553 ) 3554 3555 axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1] 3556 return data.squeeze(), axes 3557 3558 def normalize( 3559 data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7 3560 ) -> NDArray[np.float32]: 3561 data = data.astype("float32") 3562 data -= data.min(axis=axis, keepdims=True) 3563 data /= data.max(axis=axis, keepdims=True) + eps 3564 return data 3565 3566 def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]): 3567 original_shape = data.shape 3568 original_axes = list(axes) 3569 data, axes = squeeze(data, axes) 3570 3571 # take slice fom any batch or index axis if needed 3572 # and convert the first channel axis and take a slice from any additional channel axes 3573 slices: Tuple[slice, ...] = () 3574 ndim = data.ndim 3575 ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2 3576 has_c_axis = False 3577 for i, a in enumerate(axes): 3578 s = data.shape[i] 3579 assert s > 1 3580 if ( 3581 isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis)) 3582 and ndim > ndim_need 3583 ): 3584 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3585 ndim -= 1 3586 elif isinstance(a, ChannelAxis): 3587 if has_c_axis: 3588 # second channel axis 3589 data = data[slices + (slice(0, 1),)] 3590 ndim -= 1 3591 else: 3592 has_c_axis = True 3593 if s == 2: 3594 # visualize two channels with cyan and magenta 3595 data = np.concatenate( 3596 [ 3597 data[slices + (slice(1, 2),)], 3598 data[slices + (slice(0, 1),)], 3599 ( 3600 data[slices + (slice(0, 1),)] 3601 + data[slices + (slice(1, 2),)] 3602 ) 3603 / 2, # TODO: take maximum instead? 
3604 ], 3605 axis=i, 3606 ) 3607 elif data.shape[i] == 3: 3608 pass # visualize 3 channels as RGB 3609 else: 3610 # visualize first 3 channels as RGB 3611 data = data[slices + (slice(3),)] 3612 3613 assert data.shape[i] == 3 3614 3615 slices += (slice(None),) 3616 3617 data, axes = squeeze(data, axes) 3618 assert len(axes) == ndim 3619 # take slice from z axis if needed 3620 slices = () 3621 if ndim > ndim_need: 3622 for i, a in enumerate(axes): 3623 s = data.shape[i] 3624 if a.id == AxisId("z"): 3625 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3626 data, axes = squeeze(data, axes) 3627 ndim -= 1 3628 break 3629 3630 slices += (slice(None),) 3631 3632 # take slice from any space or time axis 3633 slices = () 3634 3635 for i, a in enumerate(axes): 3636 if ndim <= ndim_need: 3637 break 3638 3639 s = data.shape[i] 3640 assert s > 1 3641 if isinstance( 3642 a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis) 3643 ): 3644 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3645 ndim -= 1 3646 3647 slices += (slice(None),) 3648 3649 del slices 3650 data, axes = squeeze(data, axes) 3651 assert len(axes) == ndim 3652 3653 if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2): 3654 raise ValueError( 3655 f"Failed to construct cover image from shape {original_shape} with axes {[a.id for a in original_axes]}." 3656 ) 3657 3658 if not has_c_axis: 3659 assert ndim == 2 3660 data = np.repeat(data[:, :, None], 3, axis=2) 3661 axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB")))) 3662 ndim += 1 3663 3664 assert ndim == 3 3665 3666 # transpose axis order such that longest axis comes first... 3667 axis_order: List[int] = list(np.argsort(list(data.shape))) 3668 axis_order.reverse() 3669 # ... and channel axis is last 3670 c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0] 3671 axis_order.append(axis_order.pop(c)) 3672 axes = [axes[ao] for ao in axis_order] 3673 data = data.transpose(axis_order) 3674 3675 # h, w = data.shape[:2] 3676 # if h / w in (1.0 or 2.0): 3677 # pass 3678 # elif h / w < 2: 3679 # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images 3680 3681 norm_along = ( 3682 tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None 3683 ) 3684 # normalize the data and map to 8 bit 3685 data = normalize(data, norm_along) 3686 data = (data * 255).astype("uint8") 3687 3688 return data 3689 3690 def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]): 3691 assert im0.dtype == im1.dtype == np.uint8 3692 assert im0.shape == im1.shape 3693 assert im0.ndim == 3 3694 N, M, C = im0.shape 3695 assert C == 3 3696 out = np.ones((N, M, C), dtype="uint8") 3697 for c in range(C): 3698 outc = np.tril(im0[..., c]) 3699 mask = outc == 0 3700 outc[mask] = np.triu(im1[..., c])[mask] 3701 out[..., c] = outc 3702 3703 return out 3704 3705 if not inputs: 3706 raise ValueError("Missing test input tensor for cover generation.") 3707 3708 if not outputs: 3709 raise ValueError("Missing test output tensor for cover generation.") 3710 3711 ipt_descr, ipt = inputs[0] 3712 out_descr, out = outputs[0] 3713 3714 ipt_img = to_2d_image(ipt, ipt_descr.axes) 3715 out_img = to_2d_image(out, out_descr.axes) 3716 3717 cover_folder = Path(mkdtemp()) 3718 if ipt_img.shape == out_img.shape: 3719 covers = [cover_folder / "cover.png"] 3720 imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img)) 3721 else: 3722 covers = [cover_folder / "input.png", cover_folder / "output.png"] 3723 imwrite(covers[0], ipt_img) 3724 imwrite(covers[1], 
out_img) 3725 3726 return covers
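The size-resolution helpers defined above (`get_ns`, `get_batch_size`, `get_tensor_sizes` and `get_axis_sizes`) turn concrete input sizes into valid block shapes for all tensors. Below is a minimal sketch of how they might be combined; it assumes `model` is an already loaded and valid `ModelDescr`, and the tensor/axis ids ("raw", "x", "y") as well as the file path are made up for illustration.

from bioimageio.spec import load_description  # assumed high-level entry point for loading RDFs
from bioimageio.spec.model.v0_5 import AxisId, TensorId

model = load_description("path/to/rdf.yaml")  # expected to yield a valid ModelDescr here

# hypothetical concrete sizes for a single input tensor "raw"
input_sizes = {
    TensorId("raw"): {AxisId("batch"): 2, AxisId("y"): 500, AxisId("x"): 500}
}

# smallest parameter n per parameterized axis such that the valid size is >= the requested size
ns = model.get_ns(input_sizes)
batch_size = model.get_batch_size(input_sizes)

# resolve all input and output axis sizes for these parameters
sizes = model.get_tensor_sizes(ns, batch_size=batch_size)
print(sizes.inputs)   # e.g. {TensorId("raw"): {AxisId("batch"): 2, AxisId("y"): 512, ...}}
print(sizes.outputs)  # entries are ints or _DataDepSize(min, max) for data dependent axes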
Space unit compatible with the OME-Zarr axes specification 0.5
Time unit compatible with the OME-Zarr axes specification 0.5
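In a model description these units appear on individual axes together with a `scale`. A minimal sketch (the axis classes are defined in this module; the concrete ids, sizes and unit strings below are illustrative OME-Zarr-compatible values):

from bioimageio.spec.model.v0_5 import AxisId, SpaceInputAxis, TimeInputAxis

# a spatial input axis of 512 pixels, each pixel covering 0.34 micrometer
x_axis = SpaceInputAxis(id=AxisId("x"), size=512, unit="micrometer", scale=0.34)

# a time axis of 100 frames recorded every 2 seconds
t_axis = TimeInputAxis(id=AxisId("t"), size=100, unit="second", scale=2.0)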
230class TensorId(LowerCaseIdentifier): 231 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 232 Annotated[LowerCaseIdentifierAnno, MaxLen(32)] 233 ]
246class AxisId(LowerCaseIdentifier): 247 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 248 Annotated[ 249 LowerCaseIdentifierAnno, 250 MaxLen(16), 251 AfterValidator(_normalize_axis_id), 252 ] 253 ]
Annotates an integer to calculate a concrete axis size from a ParameterizedSize.
299class ParameterizedSize(Node): 300 """Describes a range of valid tensor axis sizes as `size = min + n*step`. 301 302 - **min** and **step** are given by the model description. 303 - All blocksize paramters n = 0,1,2,... yield a valid `size`. 304 - A greater blocksize paramter n = 0,1,2,... results in a greater **size**. 305 This allows to adjust the axis size more generically. 306 """ 307 308 N: ClassVar[Type[int]] = ParameterizedSize_N 309 """Positive integer to parameterize this axis""" 310 311 min: Annotated[int, Gt(0)] 312 step: Annotated[int, Gt(0)] 313 314 def validate_size(self, size: int) -> int: 315 if size < self.min: 316 raise ValueError(f"size {size} < {self.min}") 317 if (size - self.min) % self.step != 0: 318 raise ValueError( 319 f"axis of size {size} is not parameterized by `min + n*step` =" 320 + f" `{self.min} + n*{self.step}`" 321 ) 322 323 return size 324 325 def get_size(self, n: ParameterizedSize_N) -> int: 326 return self.min + self.step * n 327 328 def get_n(self, s: int) -> ParameterizedSize_N: 329 """return smallest n parameterizing a size greater or equal than `s`""" 330 return ceil((s - self.min) / self.step)
Describes a range of valid tensor axis sizes as size = min + n*step.
- min and step are given by the model description.
- Every block size parameter n = 0, 1, 2, ... yields a valid size.
- A greater block size parameter n results in a greater size, which allows the axis size to be adjusted generically.
314 def validate_size(self, size: int) -> int: 315 if size < self.min: 316 raise ValueError(f"size {size} < {self.min}") 317 if (size - self.min) % self.step != 0: 318 raise ValueError( 319 f"axis of size {size} is not parameterized by `min + n*step` =" 320 + f" `{self.min} + n*{self.step}`" 321 ) 322 323 return size
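For illustration, a minimal sketch of how these methods relate; the concrete min/step values are assumptions chosen for this example:
size = ParameterizedSize(min=32, step=16)
assert size.get_size(n=0) == 32       # min + 0 * step
assert size.get_size(n=2) == 64       # min + 2 * step
assert size.get_n(40) == 1            # smallest n with min + n * step >= 40
assert size.validate_size(48) == 48   # 48 == 32 + 1 * 16 is a valid size
# size.validate_size(40) would raise a ValueError: 40 is not of the form min + n * step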
333class DataDependentSize(Node): 334 min: Annotated[int, Gt(0)] = 1 335 max: Annotated[Optional[int], Gt(1)] = None 336 337 @model_validator(mode="after") 338 def _validate_max_gt_min(self): 339 if self.max is not None and self.min >= self.max: 340 raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}") 341 342 return self 343 344 def validate_size(self, size: int) -> int: 345 if size < self.min: 346 raise ValueError(f"size {size} < {self.min}") 347 348 if self.max is not None and size > self.max: 349 raise ValueError(f"size {size} > {self.max}") 350 351 return size
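A small sketch of how DataDependentSize bounds are checked; the bounds are illustrative assumptions:
size = DataDependentSize(min=1, max=1000)
assert size.validate_size(512) == 512
# size.validate_size(1500) would raise a ValueError because 1500 > max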
354class SizeReference(Node): 355 """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis. 356 357 `axis.size = reference.size * reference.scale / axis.scale + offset` 358 359 Note: 360 1. The axis and the referenced axis need to have the same unit (or no unit). 361 2. Batch axes may not be referenced. 362 3. Fractions are rounded down. 363 4. If the reference axis is `concatenable` the referencing axis is assumed to be 364 `concatenable` as well with the same block order. 365 366 Example: 367 An unisotropic input image of w*h=100*49 pixels depicts a phsical space of 200*196mm². 368 Let's assume that we want to express the image height h in relation to its width w 369 instead of only accepting input images of exactly 100*49 pixels 370 (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`). 371 372 >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2) 373 >>> h = SpaceInputAxis( 374 ... id=AxisId("h"), 375 ... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1), 376 ... unit="millimeter", 377 ... scale=4, 378 ... ) 379 >>> print(h.size.get_size(h, w)) 380 49 381 382 ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49 383 """ 384 385 tensor_id: TensorId 386 """tensor id of the reference axis""" 387 388 axis_id: AxisId 389 """axis id of the reference axis""" 390 391 offset: StrictInt = 0 392 393 def get_size( 394 self, 395 axis: Union[ 396 ChannelAxis, 397 IndexInputAxis, 398 IndexOutputAxis, 399 TimeInputAxis, 400 SpaceInputAxis, 401 TimeOutputAxis, 402 TimeOutputAxisWithHalo, 403 SpaceOutputAxis, 404 SpaceOutputAxisWithHalo, 405 ], 406 ref_axis: Union[ 407 ChannelAxis, 408 IndexInputAxis, 409 IndexOutputAxis, 410 TimeInputAxis, 411 SpaceInputAxis, 412 TimeOutputAxis, 413 TimeOutputAxisWithHalo, 414 SpaceOutputAxis, 415 SpaceOutputAxisWithHalo, 416 ], 417 n: ParameterizedSize_N = 0, 418 ref_size: Optional[int] = None, 419 ): 420 """Compute the concrete size for a given axis and its reference axis. 421 422 Args: 423 axis: The axis this `SizeReference` is the size of. 424 ref_axis: The reference axis to compute the size from. 425 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`) 426 and no fixed **ref_size** is given, 427 **n** is used to compute the size of the parameterized **ref_axis**. 428 ref_size: Overwrite the reference size instead of deriving it from 429 **ref_axis** 430 (**ref_axis.scale** is still used; any given **n** is ignored). 431 """ 432 assert axis.size == self, ( 433 "Given `axis.size` is not defined by this `SizeReference`" 434 ) 435 436 assert ref_axis.id == self.axis_id, ( 437 f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}." 438 ) 439 440 assert axis.unit == ref_axis.unit, ( 441 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`," 442 f" but {axis.unit}!={ref_axis.unit}" 443 ) 444 if ref_size is None: 445 if isinstance(ref_axis.size, (int, float)): 446 ref_size = ref_axis.size 447 elif isinstance(ref_axis.size, ParameterizedSize): 448 ref_size = ref_axis.size.get_size(n) 449 elif isinstance(ref_axis.size, DataDependentSize): 450 raise ValueError( 451 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`." 452 ) 453 elif isinstance(ref_axis.size, SizeReference): 454 raise ValueError( 455 "Reference axis referenced in `SizeReference` may not be sized by a" 456 + " `SizeReference` itself." 
457 ) 458 else: 459 assert_never(ref_axis.size) 460 461 return int(ref_size * ref_axis.scale / axis.scale + self.offset) 462 463 @staticmethod 464 def _get_unit( 465 axis: Union[ 466 ChannelAxis, 467 IndexInputAxis, 468 IndexOutputAxis, 469 TimeInputAxis, 470 SpaceInputAxis, 471 TimeOutputAxis, 472 TimeOutputAxisWithHalo, 473 SpaceOutputAxis, 474 SpaceOutputAxisWithHalo, 475 ], 476 ): 477 return axis.unit
A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.
axis.size = reference.size * reference.scale / axis.scale + offset
Note:
- The axis and the referenced axis need to have the same unit (or no unit).
- Batch axes may not be referenced.
- Fractions are rounded down.
- If the reference axis is concatenable, the referencing axis is assumed to be concatenable as well, with the same block order.
Example:
An anisotropic input image of w*h = 100*49 pixels depicts a physical space of 200*196 mm².
Let's assume that we want to express the image height h in relation to its width w
instead of only accepting input images of exactly 100*49 pixels
(for example to express a range of valid image shapes by parametrizing w, see ParameterizedSize).
>>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
>>> h = SpaceInputAxis(
... id=AxisId("h"),
... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
... unit="millimeter",
... scale=4,
... )
>>> print(h.size.get_size(h, w))
49
⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
393 def get_size( 394 self, 395 axis: Union[ 396 ChannelAxis, 397 IndexInputAxis, 398 IndexOutputAxis, 399 TimeInputAxis, 400 SpaceInputAxis, 401 TimeOutputAxis, 402 TimeOutputAxisWithHalo, 403 SpaceOutputAxis, 404 SpaceOutputAxisWithHalo, 405 ], 406 ref_axis: Union[ 407 ChannelAxis, 408 IndexInputAxis, 409 IndexOutputAxis, 410 TimeInputAxis, 411 SpaceInputAxis, 412 TimeOutputAxis, 413 TimeOutputAxisWithHalo, 414 SpaceOutputAxis, 415 SpaceOutputAxisWithHalo, 416 ], 417 n: ParameterizedSize_N = 0, 418 ref_size: Optional[int] = None, 419 ): 420 """Compute the concrete size for a given axis and its reference axis. 421 422 Args: 423 axis: The axis this `SizeReference` is the size of. 424 ref_axis: The reference axis to compute the size from. 425 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`) 426 and no fixed **ref_size** is given, 427 **n** is used to compute the size of the parameterized **ref_axis**. 428 ref_size: Overwrite the reference size instead of deriving it from 429 **ref_axis** 430 (**ref_axis.scale** is still used; any given **n** is ignored). 431 """ 432 assert axis.size == self, ( 433 "Given `axis.size` is not defined by this `SizeReference`" 434 ) 435 436 assert ref_axis.id == self.axis_id, ( 437 f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}." 438 ) 439 440 assert axis.unit == ref_axis.unit, ( 441 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`," 442 f" but {axis.unit}!={ref_axis.unit}" 443 ) 444 if ref_size is None: 445 if isinstance(ref_axis.size, (int, float)): 446 ref_size = ref_axis.size 447 elif isinstance(ref_axis.size, ParameterizedSize): 448 ref_size = ref_axis.size.get_size(n) 449 elif isinstance(ref_axis.size, DataDependentSize): 450 raise ValueError( 451 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`." 452 ) 453 elif isinstance(ref_axis.size, SizeReference): 454 raise ValueError( 455 "Reference axis referenced in `SizeReference` may not be sized by a" 456 + " `SizeReference` itself." 457 ) 458 else: 459 assert_never(ref_axis.size) 460 461 return int(ref_size * ref_axis.scale / axis.scale + self.offset)
Compute the concrete size for a given axis and its reference axis.
Arguments:
- axis: The axis this SizeReference is the size of.
- ref_axis: The reference axis to compute the size from.
- n: If the ref_axis is parameterized (of type ParameterizedSize) and no fixed ref_size is given, n is used to compute the size of the parameterized ref_axis.
- ref_size: Overwrite the reference size instead of deriving it from ref_axis (ref_axis.scale is still used; any given n is ignored).
488class WithHalo(Node): 489 halo: Annotated[int, Ge(1)] 490 """The halo should be cropped from the output tensor to avoid boundary effects. 491 It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`. 492 To document a halo that is already cropped by the model use `size.offset` instead.""" 493 494 size: Annotated[ 495 SizeReference, 496 Field( 497 examples=[ 498 10, 499 SizeReference( 500 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 501 ).model_dump(mode="json"), 502 ] 503 ), 504 ] 505 """reference to another axis with an optional offset (see `SizeReference`)"""
The halo should be cropped from the output tensor to avoid boundary effects.
It is to be cropped from both sides, i.e. size_after_crop = size - 2 * halo.
To document a halo that is already cropped by the model use size.offset instead.
reference to another axis with an optional offset (see SizeReference)
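A hedged sketch of an output axis carrying a halo; the tensor/axis ids and the halo value are illustrative assumptions, not taken from a real model:
y_out = SpaceOutputAxisWithHalo(
    id=AxisId("y"),
    size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("y")),
    halo=16,
)
# consumers crop the halo from both sides of this axis:
# size_after_crop = size - 2 * y_out.halo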
511class BatchAxis(AxisBase): 512 implemented_type: ClassVar[Literal["batch"]] = "batch" 513 if TYPE_CHECKING: 514 type: Literal["batch"] = "batch" 515 else: 516 type: Literal["batch"] 517 518 id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID 519 size: Optional[Literal[1]] = None 520 """The batch size may be fixed to 1, 521 otherwise (the default) it may be chosen arbitrarily depending on available memory""" 522 523 @property 524 def scale(self): 525 return 1.0 526 527 @property 528 def concatenable(self): 529 return True 530 531 @property 532 def unit(self): 533 return None
The batch size may be fixed to 1, otherwise (the default) it may be chosen arbitrarily depending on available memory
536class ChannelAxis(AxisBase): 537 implemented_type: ClassVar[Literal["channel"]] = "channel" 538 if TYPE_CHECKING: 539 type: Literal["channel"] = "channel" 540 else: 541 type: Literal["channel"] 542 543 id: NonBatchAxisId = AxisId("channel") 544 545 channel_names: NotEmpty[List[Identifier]] 546 547 @property 548 def size(self) -> int: 549 return len(self.channel_names) 550 551 @property 552 def concatenable(self): 553 return False 554 555 @property 556 def scale(self) -> float: 557 return 1.0 558 559 @property 560 def unit(self): 561 return None
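For example, a three-channel axis whose size is derived from channel_names (the channel names are illustrative assumptions):
rgb = ChannelAxis(
    channel_names=[Identifier("red"), Identifier("green"), Identifier("blue")]
)
assert rgb.size == 3        # size equals the number of channel names
assert not rgb.concatenable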
564class IndexAxisBase(AxisBase): 565 implemented_type: ClassVar[Literal["index"]] = "index" 566 if TYPE_CHECKING: 567 type: Literal["index"] = "index" 568 else: 569 type: Literal["index"] 570 571 id: NonBatchAxisId = AxisId("index") 572 573 @property 574 def scale(self) -> float: 575 return 1.0 576 577 @property 578 def unit(self): 579 return None
602class IndexInputAxis(IndexAxisBase, _WithInputAxisSize): 603 concatenable: bool = False 604 """If a model has a `concatenable` input axis, it can be processed blockwise, 605 splitting a longer sample axis into blocks matching its input tensor description. 606 Output axes are concatenable if they have a `SizeReference` to a concatenable 607 input axis. 608 """
If a model has a concatenable input axis, it can be processed blockwise,
splitting a longer sample axis into blocks matching its input tensor description.
Output axes are concatenable if they have a SizeReference to a concatenable
input axis.
611class IndexOutputAxis(IndexAxisBase): 612 size: Annotated[ 613 Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize], 614 Field( 615 examples=[ 616 10, 617 SizeReference( 618 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 619 ).model_dump(mode="json"), 620 ] 621 ), 622 ] 623 """The size/length of this axis can be specified as 624 - fixed integer 625 - reference to another axis with an optional offset (`SizeReference`) 626 - data dependent size using `DataDependentSize` (size is only known after model inference) 627 """
The size/length of this axis can be specified as
- a fixed integer
- a reference to another axis with an optional offset (SizeReference)
- a data dependent size using DataDependentSize (the size is only known after model inference)
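A short sketch of the three forms; the tensor and axis ids are illustrative assumptions:
fixed = IndexOutputAxis(size=10)
referenced = IndexOutputAxis(
    size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("index"))
)
data_dependent = IndexOutputAxis(size=DataDependentSize(min=1))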
630class TimeAxisBase(AxisBase): 631 implemented_type: ClassVar[Literal["time"]] = "time" 632 if TYPE_CHECKING: 633 type: Literal["time"] = "time" 634 else: 635 type: Literal["time"] 636 637 id: NonBatchAxisId = AxisId("time") 638 unit: Optional[TimeUnit] = None 639 scale: Annotated[float, Gt(0)] = 1.0
642class TimeInputAxis(TimeAxisBase, _WithInputAxisSize): 643 concatenable: bool = False 644 """If a model has a `concatenable` input axis, it can be processed blockwise, 645 splitting a longer sample axis into blocks matching its input tensor description. 646 Output axes are concatenable if they have a `SizeReference` to a concatenable 647 input axis. 648 """
If a model has a concatenable input axis, it can be processed blockwise,
splitting a longer sample axis into blocks matching its input tensor description.
Output axes are concatenable if they have a SizeReference to a concatenable
input axis.
651class SpaceAxisBase(AxisBase): 652 implemented_type: ClassVar[Literal["space"]] = "space" 653 if TYPE_CHECKING: 654 type: Literal["space"] = "space" 655 else: 656 type: Literal["space"] 657 658 id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x") 659 unit: Optional[SpaceUnit] = None 660 scale: Annotated[float, Gt(0)] = 1.0
An axis id unique across all axes of one tensor.
663class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize): 664 concatenable: bool = False 665 """If a model has a `concatenable` input axis, it can be processed blockwise, 666 splitting a longer sample axis into blocks matching its input tensor description. 667 Output axes are concatenable if they have a `SizeReference` to a concatenable 668 input axis. 669 """
If a model has a concatenable input axis, it can be processed blockwise,
splitting a longer sample axis into blocks matching its input tensor description.
Output axes are concatenable if they have a SizeReference to a concatenable
input axis.
791class NominalOrOrdinalDataDescr(Node): 792 values: TVs 793 """A fixed set of nominal or an ascending sequence of ordinal values. 794 In this case `data.type` is required to be an unsigend integer type, e.g. 'uint8'. 795 String `values` are interpreted as labels for tensor values 0, ..., N. 796 Note: as YAML 1.2 does not natively support a "set" datatype, 797 nominal values should be given as a sequence (aka list/array) as well. 798 """ 799 800 type: Annotated[ 801 NominalOrOrdinalDType, 802 Field( 803 examples=[ 804 "float32", 805 "uint8", 806 "uint16", 807 "int64", 808 "bool", 809 ], 810 ), 811 ] = "uint8" 812 813 @model_validator(mode="after") 814 def _validate_values_match_type( 815 self, 816 ) -> Self: 817 incompatible: List[Any] = [] 818 for v in self.values: 819 if self.type == "bool": 820 if not isinstance(v, bool): 821 incompatible.append(v) 822 elif self.type in DTYPE_LIMITS: 823 if ( 824 isinstance(v, (int, float)) 825 and ( 826 v < DTYPE_LIMITS[self.type].min 827 or v > DTYPE_LIMITS[self.type].max 828 ) 829 or (isinstance(v, str) and "uint" not in self.type) 830 or (isinstance(v, float) and "int" in self.type) 831 ): 832 incompatible.append(v) 833 else: 834 incompatible.append(v) 835 836 if len(incompatible) == 5: 837 incompatible.append("...") 838 break 839 840 if incompatible: 841 raise ValueError( 842 f"data type '{self.type}' incompatible with values {incompatible}" 843 ) 844 845 return self 846 847 unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None 848 849 @property 850 def range(self): 851 if isinstance(self.values[0], str): 852 return 0, len(self.values) - 1 853 else: 854 return min(self.values), max(self.values)
A fixed set of nominal or an ascending sequence of ordinal values.
In this case data.type is required to be an unsigned integer type, e.g. 'uint8'.
String values are interpreted as labels for tensor values 0, ..., N.
Note: as YAML 1.2 does not natively support a "set" datatype,
nominal values should be given as a sequence (aka list/array) as well.
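For example, string values describe labels for an unsigned integer tensor (the label names are illustrative assumptions):
labels = NominalOrOrdinalDataDescr(
    values=["background", "foreground", "boundary"],
    type="uint8",
)
assert labels.range == (0, 2)  # labels for tensor values 0, 1, 2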
871class IntervalOrRatioDataDescr(Node): 872 type: Annotated[ # TODO: rename to dtype 873 IntervalOrRatioDType, 874 Field( 875 examples=["float32", "float64", "uint8", "uint16"], 876 ), 877 ] = "float32" 878 range: Tuple[Optional[float], Optional[float]] = ( 879 None, 880 None, 881 ) 882 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor. 883 `None` corresponds to min/max of what can be expressed by **type**.""" 884 unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit" 885 scale: float = 1.0 886 """Scale for data on an interval (or ratio) scale.""" 887 offset: Optional[float] = None 888 """Offset for data on a ratio scale.""" 889 890 @model_validator(mode="before") 891 def _replace_inf(cls, data: Any): 892 if is_dict(data): 893 if "range" in data and is_sequence(data["range"]): 894 forbidden = ( 895 "inf", 896 "-inf", 897 ".inf", 898 "-.inf", 899 float("inf"), 900 float("-inf"), 901 ) 902 if any(v in forbidden for v in data["range"]): 903 issue_warning("replaced 'inf' value", value=data["range"]) 904 905 data["range"] = tuple( 906 (None if v in forbidden else v) for v in data["range"] 907 ) 908 909 return data
Tuple (minimum, maximum) specifying the allowed range of the data in this tensor.
None corresponds to min/max of what can be expressed by type.
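For example, float32 data expected to lie in the unit interval (the values are illustrative assumptions):
data = IntervalOrRatioDataDescr(type="float32", range=(0.0, 1.0))
assert data.range == (0.0, 1.0)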
processing base class
919class BinarizeKwargs(ProcessingKwargs): 920 """key word arguments for `BinarizeDescr`""" 921 922 threshold: float 923 """The fixed threshold"""
Keyword arguments for BinarizeDescr
926class BinarizeAlongAxisKwargs(ProcessingKwargs): 927 """key word arguments for `BinarizeDescr`""" 928 929 threshold: NotEmpty[List[float]] 930 """The fixed threshold values along `axis`""" 931 932 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 933 """The `threshold` axis"""
Keyword arguments for BinarizeDescr
The fixed threshold values along axis
The threshold axis
936class BinarizeDescr(ProcessingDescrBase): 937 """Binarize the tensor with a fixed threshold. 938 939 Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold` 940 will be set to one, values below the threshold to zero. 941 942 Examples: 943 - in YAML 944 ```yaml 945 postprocessing: 946 - id: binarize 947 kwargs: 948 axis: 'channel' 949 threshold: [0.25, 0.5, 0.75] 950 ``` 951 - in Python: 952 >>> postprocessing = [BinarizeDescr( 953 ... kwargs=BinarizeAlongAxisKwargs( 954 ... axis=AxisId('channel'), 955 ... threshold=[0.25, 0.5, 0.75], 956 ... ) 957 ... )] 958 """ 959 960 implemented_id: ClassVar[Literal["binarize"]] = "binarize" 961 if TYPE_CHECKING: 962 id: Literal["binarize"] = "binarize" 963 else: 964 id: Literal["binarize"] 965 kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]
Binarize the tensor with a fixed threshold.
Values above BinarizeKwargs.threshold/BinarizeAlongAxisKwargs.threshold
will be set to one, values below the threshold to zero.
Examples:
- in YAML
postprocessing:
  - id: binarize
    kwargs:
      axis: 'channel'
      threshold: [0.25, 0.5, 0.75]
- in Python:
>>> postprocessing = [BinarizeDescr(
...     kwargs=BinarizeAlongAxisKwargs(
...         axis=AxisId('channel'),
...         threshold=[0.25, 0.5, 0.75],
...     )
... )]
968class ClipDescr(ProcessingDescrBase): 969 """Set tensor values below min to min and above max to max. 970 971 See `ScaleRangeDescr` for examples. 972 """ 973 974 implemented_id: ClassVar[Literal["clip"]] = "clip" 975 if TYPE_CHECKING: 976 id: Literal["clip"] = "clip" 977 else: 978 id: Literal["clip"] 979 980 kwargs: ClipKwargs
Set tensor values below min to min and above max to max.
See ScaleRangeDescr for examples.
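A minimal sketch; the clip bounds are illustrative assumptions:
postprocessing = [ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0))]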
983class EnsureDtypeKwargs(ProcessingKwargs): 984 """key word arguments for `EnsureDtypeDescr`""" 985 986 dtype: Literal[ 987 "float32", 988 "float64", 989 "uint8", 990 "int8", 991 "uint16", 992 "int16", 993 "uint32", 994 "int32", 995 "uint64", 996 "int64", 997 "bool", 998 ]
Keyword arguments for EnsureDtypeDescr
1001class EnsureDtypeDescr(ProcessingDescrBase): 1002 """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching). 1003 1004 This can for example be used to ensure the inner neural network model gets a 1005 different input tensor data type than the fully described bioimage.io model does. 1006 1007 Examples: 1008 The described bioimage.io model (incl. preprocessing) accepts any 1009 float32-compatible tensor, normalizes it with percentiles and clipping and then 1010 casts it to uint8, which is what the neural network in this example expects. 1011 - in YAML 1012 ```yaml 1013 inputs: 1014 - data: 1015 type: float32 # described bioimage.io model is compatible with any float32 input tensor 1016 preprocessing: 1017 - id: scale_range 1018 kwargs: 1019 axes: ['y', 'x'] 1020 max_percentile: 99.8 1021 min_percentile: 5.0 1022 - id: clip 1023 kwargs: 1024 min: 0.0 1025 max: 1.0 1026 - id: ensure_dtype # the neural network of the model requires uint8 1027 kwargs: 1028 dtype: uint8 1029 ``` 1030 - in Python: 1031 >>> preprocessing = [ 1032 ... ScaleRangeDescr( 1033 ... kwargs=ScaleRangeKwargs( 1034 ... axes= (AxisId('y'), AxisId('x')), 1035 ... max_percentile= 99.8, 1036 ... min_percentile= 5.0, 1037 ... ) 1038 ... ), 1039 ... ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)), 1040 ... EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")), 1041 ... ] 1042 """ 1043 1044 implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype" 1045 if TYPE_CHECKING: 1046 id: Literal["ensure_dtype"] = "ensure_dtype" 1047 else: 1048 id: Literal["ensure_dtype"] 1049 1050 kwargs: EnsureDtypeKwargs
Cast the tensor data type to EnsureDtypeKwargs.dtype (if not matching).
This can for example be used to ensure the inner neural network model gets a different input tensor data type than the fully described bioimage.io model does.
Examples:
The described bioimage.io model (incl. preprocessing) accepts any float32-compatible tensor, normalizes it with percentiles and clipping and then casts it to uint8, which is what the neural network in this example expects.
- in YAML
inputs:
  - data:
      type: float32  # described bioimage.io model is compatible with any float32 input tensor
    preprocessing:
      - id: scale_range
        kwargs:
          axes: ['y', 'x']
          max_percentile: 99.8
          min_percentile: 5.0
      - id: clip
        kwargs:
          min: 0.0
          max: 1.0
      - id: ensure_dtype  # the neural network of the model requires uint8
        kwargs:
          dtype: uint8
- in Python:
>>> preprocessing = [
...     ScaleRangeDescr(
...         kwargs=ScaleRangeKwargs(
...             axes=(AxisId('y'), AxisId('x')),
...             max_percentile=99.8,
...             min_percentile=5.0,
...         )
...     ),
...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
... ]
1053class ScaleLinearKwargs(ProcessingKwargs): 1054 """Key word arguments for `ScaleLinearDescr`""" 1055 1056 gain: float = 1.0 1057 """multiplicative factor""" 1058 1059 offset: float = 0.0 1060 """additive term""" 1061 1062 @model_validator(mode="after") 1063 def _validate(self) -> Self: 1064 if self.gain == 1.0 and self.offset == 0.0: 1065 raise ValueError( 1066 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 1067 + " != 0.0." 1068 ) 1069 1070 return self
Keyword arguments for ScaleLinearDescr
1073class ScaleLinearAlongAxisKwargs(ProcessingKwargs): 1074 """Key word arguments for `ScaleLinearDescr`""" 1075 1076 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 1077 """The axis of gain and offset values.""" 1078 1079 gain: Union[float, NotEmpty[List[float]]] = 1.0 1080 """multiplicative factor""" 1081 1082 offset: Union[float, NotEmpty[List[float]]] = 0.0 1083 """additive term""" 1084 1085 @model_validator(mode="after") 1086 def _validate(self) -> Self: 1087 if isinstance(self.gain, list): 1088 if isinstance(self.offset, list): 1089 if len(self.gain) != len(self.offset): 1090 raise ValueError( 1091 f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match." 1092 ) 1093 else: 1094 self.offset = [float(self.offset)] * len(self.gain) 1095 elif isinstance(self.offset, list): 1096 self.gain = [float(self.gain)] * len(self.offset) 1097 else: 1098 raise ValueError( 1099 "Do not specify an `axis` for scalar gain and offset values." 1100 ) 1101 1102 if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset): 1103 raise ValueError( 1104 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 1105 + " != 0.0." 1106 ) 1107 1108 return self
Keyword arguments for ScaleLinearDescr
The axis of gain and offset values.
1111class ScaleLinearDescr(ProcessingDescrBase): 1112 """Fixed linear scaling. 1113 1114 Examples: 1115 1. Scale with scalar gain and offset 1116 - in YAML 1117 ```yaml 1118 preprocessing: 1119 - id: scale_linear 1120 kwargs: 1121 gain: 2.0 1122 offset: 3.0 1123 ``` 1124 - in Python: 1125 >>> preprocessing = [ 1126 ... ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0)) 1127 ... ] 1128 1129 2. Independent scaling along an axis 1130 - in YAML 1131 ```yaml 1132 preprocessing: 1133 - id: scale_linear 1134 kwargs: 1135 axis: 'channel' 1136 gain: [1.0, 2.0, 3.0] 1137 ``` 1138 - in Python: 1139 >>> preprocessing = [ 1140 ... ScaleLinearDescr( 1141 ... kwargs=ScaleLinearAlongAxisKwargs( 1142 ... axis=AxisId("channel"), 1143 ... gain=[1.0, 2.0, 3.0], 1144 ... ) 1145 ... ) 1146 ... ] 1147 1148 """ 1149 1150 implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear" 1151 if TYPE_CHECKING: 1152 id: Literal["scale_linear"] = "scale_linear" 1153 else: 1154 id: Literal["scale_linear"] 1155 kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]
Fixed linear scaling.
Examples:
1. Scale with scalar gain and offset
- in YAML
preprocessing:
  - id: scale_linear
    kwargs:
      gain: 2.0
      offset: 3.0
- in Python:
>>> preprocessing = [
...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain=2.0, offset=3.0))
... ]
2. Independent scaling along an axis
- in YAML
preprocessing:
  - id: scale_linear
    kwargs:
      axis: 'channel'
      gain: [1.0, 2.0, 3.0]
- in Python:
>>> preprocessing = [
...     ScaleLinearDescr(
...         kwargs=ScaleLinearAlongAxisKwargs(
...             axis=AxisId("channel"),
...             gain=[1.0, 2.0, 3.0],
...         )
...     )
... ]
1158class SigmoidDescr(ProcessingDescrBase): 1159 """The logistic sigmoid function, a.k.a. expit function. 1160 1161 Examples: 1162 - in YAML 1163 ```yaml 1164 postprocessing: 1165 - id: sigmoid 1166 ``` 1167 - in Python: 1168 >>> postprocessing = [SigmoidDescr()] 1169 """ 1170 1171 implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid" 1172 if TYPE_CHECKING: 1173 id: Literal["sigmoid"] = "sigmoid" 1174 else: 1175 id: Literal["sigmoid"] 1176 1177 @property 1178 def kwargs(self) -> ProcessingKwargs: 1179 """empty kwargs""" 1180 return ProcessingKwargs()
The logistic sigmoid function, a.k.a. expit function.
Examples:
- in YAML
postprocessing:
- id: sigmoid
- in Python:
>>> postprocessing = [SigmoidDescr()]
1183class SoftmaxKwargs(ProcessingKwargs): 1184 """key word arguments for `SoftmaxDescr`""" 1185 1186 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] = AxisId("channel") 1187 """The axis to apply the softmax function along. 1188 Note: 1189 Defaults to 'channel' axis 1190 (which may not exist, in which case 1191 a different axis id has to be specified). 1192 """
Keyword arguments for SoftmaxDescr
The axis to apply the softmax function along.
Note:
Defaults to 'channel' axis (which may not exist, in which case a different axis id has to be specified).
1195class SoftmaxDescr(ProcessingDescrBase): 1196 """The softmax function. 1197 1198 Examples: 1199 - in YAML 1200 ```yaml 1201 postprocessing: 1202 - id: softmax 1203 kwargs: 1204 axis: channel 1205 ``` 1206 - in Python: 1207 >>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))] 1208 """ 1209 1210 implemented_id: ClassVar[Literal["softmax"]] = "softmax" 1211 if TYPE_CHECKING: 1212 id: Literal["softmax"] = "softmax" 1213 else: 1214 id: Literal["softmax"] 1215 1216 kwargs: SoftmaxKwargs = Field(default_factory=SoftmaxKwargs.model_construct)
The softmax function.
Examples:
- in YAML
postprocessing:
  - id: softmax
    kwargs:
      axis: channel
- in Python:
>>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))]
1219class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs): 1220 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 1221 1222 mean: float 1223 """The mean value to normalize with.""" 1224 1225 std: Annotated[float, Ge(1e-6)] 1226 """The standard deviation value to normalize with."""
Keyword arguments for FixedZeroMeanUnitVarianceDescr
The standard deviation value to normalize with.
1229class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs): 1230 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 1231 1232 mean: NotEmpty[List[float]] 1233 """The mean value(s) to normalize with.""" 1234 1235 std: NotEmpty[List[Annotated[float, Ge(1e-6)]]] 1236 """The standard deviation value(s) to normalize with. 1237 Size must match `mean` values.""" 1238 1239 axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])] 1240 """The axis of the mean/std values to normalize each entry along that dimension 1241 separately.""" 1242 1243 @model_validator(mode="after") 1244 def _mean_and_std_match(self) -> Self: 1245 if len(self.mean) != len(self.std): 1246 raise ValueError( 1247 f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})" 1248 + " must match." 1249 ) 1250 1251 return self
Keyword arguments for FixedZeroMeanUnitVarianceDescr
The mean value(s) to normalize with.
The standard deviation value(s) to normalize with.
Size must match mean values.
The axis of the mean/std values to normalize each entry along that dimension separately.
1254class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase): 1255 """Subtract a given mean and divide by the standard deviation. 1256 1257 Normalize with fixed, precomputed values for 1258 `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std` 1259 Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given 1260 axes. 1261 1262 Examples: 1263 1. scalar value for whole tensor 1264 - in YAML 1265 ```yaml 1266 preprocessing: 1267 - id: fixed_zero_mean_unit_variance 1268 kwargs: 1269 mean: 103.5 1270 std: 13.7 1271 ``` 1272 - in Python 1273 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 1274 ... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7) 1275 ... )] 1276 1277 2. independently along an axis 1278 - in YAML 1279 ```yaml 1280 preprocessing: 1281 - id: fixed_zero_mean_unit_variance 1282 kwargs: 1283 axis: channel 1284 mean: [101.5, 102.5, 103.5] 1285 std: [11.7, 12.7, 13.7] 1286 ``` 1287 - in Python 1288 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 1289 ... kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 1290 ... axis=AxisId("channel"), 1291 ... mean=[101.5, 102.5, 103.5], 1292 ... std=[11.7, 12.7, 13.7], 1293 ... ) 1294 ... )] 1295 """ 1296 1297 implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = ( 1298 "fixed_zero_mean_unit_variance" 1299 ) 1300 if TYPE_CHECKING: 1301 id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance" 1302 else: 1303 id: Literal["fixed_zero_mean_unit_variance"] 1304 1305 kwargs: Union[ 1306 FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs 1307 ]
Subtract a given mean and divide by the standard deviation.
Normalize with fixed, precomputed values for
FixedZeroMeanUnitVarianceKwargs.mean and FixedZeroMeanUnitVarianceKwargs.std.
Use FixedZeroMeanUnitVarianceAlongAxisKwargs for independent scaling along given
axes.
Examples:
1. scalar value for whole tensor
- in YAML
preprocessing:
  - id: fixed_zero_mean_unit_variance
    kwargs:
      mean: 103.5
      std: 13.7
- in Python
>>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
... )]
2. independently along an axis
- in YAML
preprocessing:
  - id: fixed_zero_mean_unit_variance
    kwargs:
      axis: channel
      mean: [101.5, 102.5, 103.5]
      std: [11.7, 12.7, 13.7]
- in Python
>>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
...     kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
...         axis=AxisId("channel"),
...         mean=[101.5, 102.5, 103.5],
...         std=[11.7, 12.7, 13.7],
...     )
... )]
1310class ZeroMeanUnitVarianceKwargs(ProcessingKwargs): 1311 """key word arguments for `ZeroMeanUnitVarianceDescr`""" 1312 1313 axes: Annotated[ 1314 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1315 ] = None 1316 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 1317 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1318 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1319 To normalize each sample independently leave out the 'batch' axis. 1320 Default: Scale all axes jointly.""" 1321 1322 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1323 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""
Keyword arguments for ZeroMeanUnitVarianceDescr
The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
resulting in a tensor of equal shape normalized per channel, specify axes=('batch', 'x', 'y').
To normalize each sample independently leave out the 'batch' axis.
Default: Scale all axes jointly.
epsilon for numeric stability: out = (tensor - mean) / (std + eps).
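For example, to normalize each channel separately but jointly over the batch, y and x axes (axis ids are illustrative assumptions; ZeroMeanUnitVarianceDescr is documented just below):
preprocessing = [
    ZeroMeanUnitVarianceDescr(
        kwargs=ZeroMeanUnitVarianceKwargs(
            axes=(AxisId("batch"), AxisId("x"), AxisId("y"))
        )
    )
]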
1326class ZeroMeanUnitVarianceDescr(ProcessingDescrBase): 1327 """Subtract mean and divide by variance. 1328 1329 Examples: 1330 Subtract tensor mean and variance 1331 - in YAML 1332 ```yaml 1333 preprocessing: 1334 - id: zero_mean_unit_variance 1335 ``` 1336 - in Python 1337 >>> preprocessing = [ZeroMeanUnitVarianceDescr()] 1338 """ 1339 1340 implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = ( 1341 "zero_mean_unit_variance" 1342 ) 1343 if TYPE_CHECKING: 1344 id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance" 1345 else: 1346 id: Literal["zero_mean_unit_variance"] 1347 1348 kwargs: ZeroMeanUnitVarianceKwargs = Field( 1349 default_factory=ZeroMeanUnitVarianceKwargs.model_construct 1350 )
Subtract mean and divide by variance.
Examples:
Subtract tensor mean and variance
- in YAML
preprocessing:
  - id: zero_mean_unit_variance
- in Python
>>> preprocessing = [ZeroMeanUnitVarianceDescr()]
1353class ScaleRangeKwargs(ProcessingKwargs): 1354 """key word arguments for `ScaleRangeDescr` 1355 1356 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default) 1357 this processing step normalizes data to the [0, 1] intervall. 1358 For other percentiles the normalized values will partially be outside the [0, 1] 1359 intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the 1360 normalized values to a range. 1361 """ 1362 1363 axes: Annotated[ 1364 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1365 ] = None 1366 """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value. 1367 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1368 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1369 To normalize samples independently, leave out the "batch" axis. 1370 Default: Scale all axes jointly.""" 1371 1372 min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0 1373 """The lower percentile used to determine the value to align with zero.""" 1374 1375 max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0 1376 """The upper percentile used to determine the value to align with one. 1377 Has to be bigger than `min_percentile`. 1378 The range is 1 to 100 instead of 0 to 100 to avoid mistakenly 1379 accepting percentiles specified in the range 0.0 to 1.0.""" 1380 1381 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1382 """Epsilon for numeric stability. 1383 `out = (tensor - v_lower) / (v_upper - v_lower + eps)`; 1384 with `v_lower,v_upper` values at the respective percentiles.""" 1385 1386 reference_tensor: Optional[TensorId] = None 1387 """Tensor ID to compute the percentiles from. Default: The tensor itself. 1388 For any tensor in `inputs` only input tensor references are allowed.""" 1389 1390 @field_validator("max_percentile", mode="after") 1391 @classmethod 1392 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float: 1393 if (min_p := info.data["min_percentile"]) >= value: 1394 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}") 1395 1396 return value
Keyword arguments for ScaleRangeDescr
For min_percentile=0.0 (the default) and max_percentile=100 (the default)
this processing step normalizes data to the [0, 1] interval.
For other percentiles the normalized values will partially be outside the [0, 1]
interval. Use ScaleRangeDescr followed by ClipDescr if you want to limit the
normalized values to a range.
The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
resulting in a tensor of equal shape normalized per channel, specify axes=('batch', 'x', 'y').
To normalize samples independently, leave out the "batch" axis.
Default: Scale all axes jointly.
The lower percentile used to determine the value to align with zero.
The upper percentile used to determine the value to align with one.
Has to be bigger than min_percentile.
The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
accepting percentiles specified in the range 0.0 to 1.0.
Epsilon for numeric stability.
out = (tensor - v_lower) / (v_upper - v_lower + eps);
with v_lower,v_upper values at the respective percentiles.
Tensor ID to compute the percentiles from. Default: The tensor itself.
For any tensor in inputs only input tensor references are allowed.
1390 @field_validator("max_percentile", mode="after") 1391 @classmethod 1392 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float: 1393 if (min_p := info.data["min_percentile"]) >= value: 1394 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}") 1395 1396 return value
1399class ScaleRangeDescr(ProcessingDescrBase): 1400 """Scale with percentiles. 1401 1402 Examples: 1403 1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0 1404 - in YAML 1405 ```yaml 1406 preprocessing: 1407 - id: scale_range 1408 kwargs: 1409 axes: ['y', 'x'] 1410 max_percentile: 99.8 1411 min_percentile: 5.0 1412 ``` 1413 - in Python 1414 >>> preprocessing = [ 1415 ... ScaleRangeDescr( 1416 ... kwargs=ScaleRangeKwargs( 1417 ... axes= (AxisId('y'), AxisId('x')), 1418 ... max_percentile= 99.8, 1419 ... min_percentile= 5.0, 1420 ... ) 1421 ... ), 1422 ... ClipDescr( 1423 ... kwargs=ClipKwargs( 1424 ... min=0.0, 1425 ... max=1.0, 1426 ... ) 1427 ... ), 1428 ... ] 1429 1430 2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles. 1431 - in YAML 1432 ```yaml 1433 preprocessing: 1434 - id: scale_range 1435 kwargs: 1436 axes: ['y', 'x'] 1437 max_percentile: 99.8 1438 min_percentile: 5.0 1439 - id: scale_range 1440 - id: clip 1441 kwargs: 1442 min: 0.0 1443 max: 1.0 1444 ``` 1445 - in Python 1446 >>> preprocessing = [ScaleRangeDescr( 1447 ... kwargs=ScaleRangeKwargs( 1448 ... axes= (AxisId('y'), AxisId('x')), 1449 ... max_percentile= 99.8, 1450 ... min_percentile= 5.0, 1451 ... ) 1452 ... )] 1453 1454 """ 1455 1456 implemented_id: ClassVar[Literal["scale_range"]] = "scale_range" 1457 if TYPE_CHECKING: 1458 id: Literal["scale_range"] = "scale_range" 1459 else: 1460 id: Literal["scale_range"] 1461 kwargs: ScaleRangeKwargs = Field(default_factory=ScaleRangeKwargs.model_construct)
Scale with percentiles.
Examples:
1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
- in YAML
preprocessing:
  - id: scale_range
    kwargs:
      axes: ['y', 'x']
      max_percentile: 99.8
      min_percentile: 5.0
- in Python
>>> preprocessing = [
... ScaleRangeDescr(
... kwargs=ScaleRangeKwargs(
... axes= (AxisId('y'), AxisId('x')),
... max_percentile= 99.8,
... min_percentile= 5.0,
... )
... ),
... ClipDescr(
... kwargs=ClipKwargs(
... min=0.0,
... max=1.0,
... )
... ),
... ]
2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
- in YAML
preprocessing:
  - id: scale_range
    kwargs:
      axes: ['y', 'x']
      max_percentile: 99.8
      min_percentile: 5.0
  - id: clip
    kwargs:
      min: 0.0
      max: 1.0
- in Python
>>> preprocessing = [
...     ScaleRangeDescr(
...         kwargs=ScaleRangeKwargs(
...             axes=(AxisId('y'), AxisId('x')),
...             max_percentile=99.8,
...             min_percentile=5.0,
...         )
...     ),
...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
... ]
1464class ScaleMeanVarianceKwargs(ProcessingKwargs): 1465 """key word arguments for `ScaleMeanVarianceKwargs`""" 1466 1467 reference_tensor: TensorId 1468 """Name of tensor to match.""" 1469 1470 axes: Annotated[ 1471 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1472 ] = None 1473 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 1474 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1475 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1476 To normalize samples independently, leave out the 'batch' axis. 1477 Default: Scale all axes jointly.""" 1478 1479 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1480 """Epsilon for numeric stability: 1481 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""
Keyword arguments for ScaleMeanVarianceDescr
The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
resulting in a tensor of equal shape normalized per channel, specify axes=('batch', 'x', 'y').
To normalize samples independently, leave out the 'batch' axis.
Default: Scale all axes jointly.
Epsilon for numeric stability:
out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.
1484class ScaleMeanVarianceDescr(ProcessingDescrBase): 1485 """Scale a tensor's data distribution to match another tensor's mean/std. 1486 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.` 1487 """ 1488 1489 implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance" 1490 if TYPE_CHECKING: 1491 id: Literal["scale_mean_variance"] = "scale_mean_variance" 1492 else: 1493 id: Literal["scale_mean_variance"] 1494 kwargs: ScaleMeanVarianceKwargs
Scale a tensor's data distribution to match another tensor's mean/std.
out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.
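The docstring gives no example; a hedged sketch of matching an output's distribution to a reference tensor (tensor id and axes are illustrative assumptions):
postprocessing = [
    ScaleMeanVarianceDescr(
        kwargs=ScaleMeanVarianceKwargs(
            reference_tensor=TensorId("raw"),
            axes=(AxisId("x"), AxisId("y")),
        )
    )
]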
1530class TensorDescrBase(Node, Generic[IO_AxisT]): 1531 id: TensorId 1532 """Tensor id. No duplicates are allowed.""" 1533 1534 description: Annotated[str, MaxLen(128)] = "" 1535 """free text description""" 1536 1537 axes: NotEmpty[Sequence[IO_AxisT]] 1538 """tensor axes""" 1539 1540 @property 1541 def shape(self): 1542 return tuple(a.size for a in self.axes) 1543 1544 @field_validator("axes", mode="after", check_fields=False) 1545 @classmethod 1546 def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]: 1547 batch_axes = [a for a in axes if a.type == "batch"] 1548 if len(batch_axes) > 1: 1549 raise ValueError( 1550 f"Only one batch axis (per tensor) allowed, but got {batch_axes}" 1551 ) 1552 1553 seen_ids: Set[AxisId] = set() 1554 duplicate_axes_ids: Set[AxisId] = set() 1555 for a in axes: 1556 (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id) 1557 1558 if duplicate_axes_ids: 1559 raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}") 1560 1561 return axes 1562 1563 test_tensor: FAIR[Optional[FileDescr_]] = None 1564 """An example tensor to use for testing. 1565 Using the model with the test input tensors is expected to yield the test output tensors. 1566 Each test tensor has be a an ndarray in the 1567 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format). 1568 The file extension must be '.npy'.""" 1569 1570 sample_tensor: FAIR[Optional[FileDescr_]] = None 1571 """A sample tensor to illustrate a possible input/output for the model, 1572 The sample image primarily serves to inform a human user about an example use case 1573 and is typically stored as .hdf5, .png or .tiff. 1574 It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats) 1575 (numpy's `.npy` format is not supported). 1576 The image dimensionality has to match the number of axes specified in this tensor description. 1577 """ 1578 1579 @model_validator(mode="after") 1580 def _validate_sample_tensor(self) -> Self: 1581 if self.sample_tensor is None or not get_validation_context().perform_io_checks: 1582 return self 1583 1584 reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256) 1585 tensor: NDArray[Any] = imread( # pyright: ignore[reportUnknownVariableType] 1586 reader.read(), 1587 extension=PurePosixPath(reader.original_file_name).suffix, 1588 ) 1589 n_dims = len(tensor.squeeze().shape) 1590 n_dims_min = n_dims_max = len(self.axes) 1591 1592 for a in self.axes: 1593 if isinstance(a, BatchAxis): 1594 n_dims_min -= 1 1595 elif isinstance(a.size, int): 1596 if a.size == 1: 1597 n_dims_min -= 1 1598 elif isinstance(a.size, (ParameterizedSize, DataDependentSize)): 1599 if a.size.min == 1: 1600 n_dims_min -= 1 1601 elif isinstance(a.size, SizeReference): 1602 if a.size.offset < 2: 1603 # size reference may result in singleton axis 1604 n_dims_min -= 1 1605 else: 1606 assert_never(a.size) 1607 1608 n_dims_min = max(0, n_dims_min) 1609 if n_dims < n_dims_min or n_dims > n_dims_max: 1610 raise ValueError( 1611 f"Expected sample tensor to have {n_dims_min} to" 1612 + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})." 1613 ) 1614 1615 return self 1616 1617 data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = ( 1618 IntervalOrRatioDataDescr() 1619 ) 1620 """Description of the tensor's data values, optionally per channel. 
1621 If specified per channel, the data `type` needs to match across channels.""" 1622 1623 @property 1624 def dtype( 1625 self, 1626 ) -> Literal[ 1627 "float32", 1628 "float64", 1629 "uint8", 1630 "int8", 1631 "uint16", 1632 "int16", 1633 "uint32", 1634 "int32", 1635 "uint64", 1636 "int64", 1637 "bool", 1638 ]: 1639 """dtype as specified under `data.type` or `data[i].type`""" 1640 if isinstance(self.data, collections.abc.Sequence): 1641 return self.data[0].type 1642 else: 1643 return self.data.type 1644 1645 @field_validator("data", mode="after") 1646 @classmethod 1647 def _check_data_type_across_channels( 1648 cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] 1649 ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]: 1650 if not isinstance(value, list): 1651 return value 1652 1653 dtypes = {t.type for t in value} 1654 if len(dtypes) > 1: 1655 raise ValueError( 1656 "Tensor data descriptions per channel need to agree in their data" 1657 + f" `type`, but found {dtypes}." 1658 ) 1659 1660 return value 1661 1662 @model_validator(mode="after") 1663 def _check_data_matches_channelaxis(self) -> Self: 1664 if not isinstance(self.data, (list, tuple)): 1665 return self 1666 1667 for a in self.axes: 1668 if isinstance(a, ChannelAxis): 1669 size = a.size 1670 assert isinstance(size, int) 1671 break 1672 else: 1673 return self 1674 1675 if len(self.data) != size: 1676 raise ValueError( 1677 f"Got tensor data descriptions for {len(self.data)} channels, but" 1678 + f" '{a.id}' axis has size {size}." 1679 ) 1680 1681 return self 1682 1683 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]: 1684 if len(array.shape) != len(self.axes): 1685 raise ValueError( 1686 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})" 1687 + f" incompatible with {len(self.axes)} axes." 1688 ) 1689 return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
An example tensor to use for testing. Using the model with the test input tensors is expected to yield the test output tensors. Each test tensor has to be an ndarray in the numpy.lib file format. The file extension must be '.npy'.
A sample tensor to illustrate a possible input/output for the model.
The sample image primarily serves to inform a human user about an example use case
and is typically stored as .hdf5, .png or .tiff.
It has to be readable by the imageio library
(numpy's .npy format is not supported).
The image dimensionality has to match the number of axes specified in this tensor description.
Description of the tensor's data values, optionally per channel.
If specified per channel, the data type needs to match across channels.
1623 @property 1624 def dtype( 1625 self, 1626 ) -> Literal[ 1627 "float32", 1628 "float64", 1629 "uint8", 1630 "int8", 1631 "uint16", 1632 "int16", 1633 "uint32", 1634 "int32", 1635 "uint64", 1636 "int64", 1637 "bool", 1638 ]: 1639 """dtype as specified under `data.type` or `data[i].type`""" 1640 if isinstance(self.data, collections.abc.Sequence): 1641 return self.data[0].type 1642 else: 1643 return self.data.type
dtype as specified under data.type or data[i].type
1683 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]: 1684 if len(array.shape) != len(self.axes): 1685 raise ValueError( 1686 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})" 1687 + f" incompatible with {len(self.axes)} axes." 1688 ) 1689 return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
1692class InputTensorDescr(TensorDescrBase[InputAxis]): 1693 id: TensorId = TensorId("input") 1694 """Input tensor id. 1695 No duplicates are allowed across all inputs and outputs.""" 1696 1697 optional: bool = False 1698 """indicates that this tensor may be `None`""" 1699 1700 preprocessing: List[PreprocessingDescr] = Field( 1701 default_factory=cast(Callable[[], List[PreprocessingDescr]], list) 1702 ) 1703 1704 """Description of how this input should be preprocessed. 1705 1706 notes: 1707 - If preprocessing does not start with an 'ensure_dtype' entry, it is added 1708 to ensure an input tensor's data type matches the input tensor's data description. 1709 - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 1710 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally 1711 changing the data type. 1712 """ 1713 1714 @model_validator(mode="after") 1715 def _validate_preprocessing_kwargs(self) -> Self: 1716 axes_ids = [a.id for a in self.axes] 1717 for p in self.preprocessing: 1718 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 1719 if kwargs_axes is None: 1720 continue 1721 1722 if not isinstance(kwargs_axes, collections.abc.Sequence): 1723 raise ValueError( 1724 f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}" 1725 ) 1726 1727 if any(a not in axes_ids for a in kwargs_axes): 1728 raise ValueError( 1729 "`preprocessing.i.kwargs.axes` needs to be subset of axes ids" 1730 ) 1731 1732 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 1733 dtype = self.data.type 1734 else: 1735 dtype = self.data[0].type 1736 1737 # ensure `preprocessing` begins with `EnsureDtypeDescr` 1738 if not self.preprocessing or not isinstance( 1739 self.preprocessing[0], EnsureDtypeDescr 1740 ): 1741 self.preprocessing.insert( 1742 0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1743 ) 1744 1745 # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 1746 if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)): 1747 self.preprocessing.append( 1748 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1749 ) 1750 1751 return self
Description of how this input should be preprocessed.
notes:
- If preprocessing does not start with an 'ensure_dtype' entry, it is added to ensure an input tensor's data type matches the input tensor's data description.
- If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally changing the data type.
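A hedged sketch of a minimal input tensor description; the ids, axis sizes and dtype are illustrative assumptions, and the optional test_tensor/sample_tensor files are omitted for brevity:
input_descr = InputTensorDescr(
    id=TensorId("raw"),
    axes=[
        BatchAxis(),
        ChannelAxis(channel_names=[Identifier("c0")]),
        SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=64, step=16)),
        SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=64, step=16)),
    ],
    data=IntervalOrRatioDataDescr(type="float32"),
    preprocessing=[ZeroMeanUnitVarianceDescr()],
)
# the validator above wraps `preprocessing` with 'ensure_dtype' steps automatically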
1754def convert_axes( 1755 axes: str, 1756 *, 1757 shape: Union[ 1758 Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4 1759 ], 1760 tensor_type: Literal["input", "output"], 1761 halo: Optional[Sequence[int]], 1762 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 1763): 1764 ret: List[AnyAxis] = [] 1765 for i, a in enumerate(axes): 1766 axis_type = _AXIS_TYPE_MAP.get(a, a) 1767 if axis_type == "batch": 1768 ret.append(BatchAxis()) 1769 continue 1770 1771 scale = 1.0 1772 if isinstance(shape, _ParameterizedInputShape_v0_4): 1773 if shape.step[i] == 0: 1774 size = shape.min[i] 1775 else: 1776 size = ParameterizedSize(min=shape.min[i], step=shape.step[i]) 1777 elif isinstance(shape, _ImplicitOutputShape_v0_4): 1778 ref_t = str(shape.reference_tensor) 1779 if ref_t.count(".") == 1: 1780 t_id, orig_a_id = ref_t.split(".") 1781 else: 1782 t_id = ref_t 1783 orig_a_id = a 1784 1785 a_id = _AXIS_ID_MAP.get(orig_a_id, a) 1786 if not (orig_scale := shape.scale[i]): 1787 # old way to insert a new axis dimension 1788 size = int(2 * shape.offset[i]) 1789 else: 1790 scale = 1 / orig_scale 1791 if axis_type in ("channel", "index"): 1792 # these axes no longer have a scale 1793 offset_from_scale = orig_scale * size_refs.get( 1794 _TensorName_v0_4(t_id), {} 1795 ).get(orig_a_id, 0) 1796 else: 1797 offset_from_scale = 0 1798 size = SizeReference( 1799 tensor_id=TensorId(t_id), 1800 axis_id=AxisId(a_id), 1801 offset=int(offset_from_scale + 2 * shape.offset[i]), 1802 ) 1803 else: 1804 size = shape[i] 1805 1806 if axis_type == "time": 1807 if tensor_type == "input": 1808 ret.append(TimeInputAxis(size=size, scale=scale)) 1809 else: 1810 assert not isinstance(size, ParameterizedSize) 1811 if halo is None: 1812 ret.append(TimeOutputAxis(size=size, scale=scale)) 1813 else: 1814 assert not isinstance(size, int) 1815 ret.append( 1816 TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i]) 1817 ) 1818 1819 elif axis_type == "index": 1820 if tensor_type == "input": 1821 ret.append(IndexInputAxis(size=size)) 1822 else: 1823 if isinstance(size, ParameterizedSize): 1824 size = DataDependentSize(min=size.min) 1825 1826 ret.append(IndexOutputAxis(size=size)) 1827 elif axis_type == "channel": 1828 assert not isinstance(size, ParameterizedSize) 1829 if isinstance(size, SizeReference): 1830 warnings.warn( 1831 "Conversion of channel size from an implicit output shape may be" 1832 + " wrong" 1833 ) 1834 ret.append( 1835 ChannelAxis( 1836 channel_names=[ 1837 Identifier(f"channel{i}") for i in range(size.offset) 1838 ] 1839 ) 1840 ) 1841 else: 1842 ret.append( 1843 ChannelAxis( 1844 channel_names=[Identifier(f"channel{i}") for i in range(size)] 1845 ) 1846 ) 1847 elif axis_type == "space": 1848 if tensor_type == "input": 1849 ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale)) 1850 else: 1851 assert not isinstance(size, ParameterizedSize) 1852 if halo is None or halo[i] == 0: 1853 ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale)) 1854 elif isinstance(size, int): 1855 raise NotImplementedError( 1856 f"output axis with halo and fixed size (here {size}) not allowed" 1857 ) 1858 else: 1859 ret.append( 1860 SpaceOutputAxisWithHalo( 1861 id=AxisId(a), size=size, scale=scale, halo=halo[i] 1862 ) 1863 ) 1864 1865 return ret
2030class OutputTensorDescr(TensorDescrBase[OutputAxis]): 2031 id: TensorId = TensorId("output") 2032 """Output tensor id. 2033 No duplicates are allowed across all inputs and outputs.""" 2034 2035 postprocessing: List[PostprocessingDescr] = Field( 2036 default_factory=cast(Callable[[], List[PostprocessingDescr]], list) 2037 ) 2038 """Description of how this output should be postprocessed. 2039 2040 note: `postprocessing` always ends with an 'ensure_dtype' operation. 2041 If not given this is added to cast to this tensor's `data.type`. 2042 """ 2043 2044 @model_validator(mode="after") 2045 def _validate_postprocessing_kwargs(self) -> Self: 2046 axes_ids = [a.id for a in self.axes] 2047 for p in self.postprocessing: 2048 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 2049 if kwargs_axes is None: 2050 continue 2051 2052 if not isinstance(kwargs_axes, collections.abc.Sequence): 2053 raise ValueError( 2054 f"expected `axes` sequence, but got {type(kwargs_axes)}" 2055 ) 2056 2057 if any(a not in axes_ids for a in kwargs_axes): 2058 raise ValueError("`kwargs.axes` needs to be subset of axes ids") 2059 2060 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 2061 dtype = self.data.type 2062 else: 2063 dtype = self.data[0].type 2064 2065 # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 2066 if not self.postprocessing or not isinstance( 2067 self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr) 2068 ): 2069 self.postprocessing.append( 2070 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 2071 ) 2072 return self
Description of how this output should be postprocessed.
note: postprocessing always ends with an 'ensure_dtype' operation.
If not given, one is added to cast to this tensor's data.type.
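Because of this guarantee, a consumer can rely on the final postprocessing step. A hedged sketch, assuming an already loaded model description bound to the hypothetical name model:

last_step = model.outputs[0].postprocessing[-1]
# after validation the last step is always an ensure_dtype (or binarize) operation,
# so the produced output can be cast to the described data.type
assert isinstance(last_step, (EnsureDtypeDescr, BinarizeDescr))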
2122def validate_tensors( 2123 tensors: Mapping[TensorId, Tuple[TensorDescr, Optional[NDArray[Any]]]], 2124 tensor_origin: Literal[ 2125 "test_tensor" 2126 ], # for more precise error messages, e.g. 'test_tensor' 2127): 2128 all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, Optional[int]]]] = {} 2129 2130 def e_msg(d: TensorDescr): 2131 return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]" 2132 2133 for descr, array in tensors.values(): 2134 if array is None: 2135 axis_sizes = {a.id: None for a in descr.axes} 2136 else: 2137 try: 2138 axis_sizes = descr.get_axis_sizes_for_array(array) 2139 except ValueError as e: 2140 raise ValueError(f"{e_msg(descr)} {e}") 2141 2142 all_tensor_axes[descr.id] = {a.id: (a, axis_sizes[a.id]) for a in descr.axes} 2143 2144 for descr, array in tensors.values(): 2145 if array is None: 2146 continue 2147 2148 if descr.dtype in ("float32", "float64"): 2149 invalid_test_tensor_dtype = array.dtype.name not in ( 2150 "float32", 2151 "float64", 2152 "uint8", 2153 "int8", 2154 "uint16", 2155 "int16", 2156 "uint32", 2157 "int32", 2158 "uint64", 2159 "int64", 2160 ) 2161 else: 2162 invalid_test_tensor_dtype = array.dtype.name != descr.dtype 2163 2164 if invalid_test_tensor_dtype: 2165 raise ValueError( 2166 f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not" 2167 + f" match described dtype '{descr.dtype}'" 2168 ) 2169 2170 if array.min() > -1e-4 and array.max() < 1e-4: 2171 raise ValueError( 2172 "Output values are too small for reliable testing." 2173 + f" Values <-1e5 or >=1e5 must be present in {tensor_origin}" 2174 ) 2175 2176 for a in descr.axes: 2177 actual_size = all_tensor_axes[descr.id][a.id][1] 2178 if actual_size is None: 2179 continue 2180 2181 if a.size is None: 2182 continue 2183 2184 if isinstance(a.size, int): 2185 if actual_size != a.size: 2186 raise ValueError( 2187 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' " 2188 + f"has incompatible size {actual_size}, expected {a.size}" 2189 ) 2190 elif isinstance(a.size, ParameterizedSize): 2191 _ = a.size.validate_size(actual_size) 2192 elif isinstance(a.size, DataDependentSize): 2193 _ = a.size.validate_size(actual_size) 2194 elif isinstance(a.size, SizeReference): 2195 ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id) 2196 if ref_tensor_axes is None: 2197 raise ValueError( 2198 f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor" 2199 + f" reference '{a.size.tensor_id}'" 2200 ) 2201 2202 ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None)) 2203 if ref_axis is None or ref_size is None: 2204 raise ValueError( 2205 f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis" 2206 + f" reference '{a.size.tensor_id}.{a.size.axis_id}" 2207 ) 2208 2209 if a.unit != ref_axis.unit: 2210 raise ValueError( 2211 f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires" 2212 + " axis and reference axis to have the same `unit`, but" 2213 + f" {a.unit}!={ref_axis.unit}" 2214 ) 2215 2216 if actual_size != ( 2217 expected_size := ( 2218 ref_size * ref_axis.scale / a.scale + a.size.offset 2219 ) 2220 ): 2221 raise ValueError( 2222 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size" 2223 + f" {actual_size} invalid for referenced size {ref_size};" 2224 + f" expected {expected_size}" 2225 ) 2226 else: 2227 assert_never(a.size)
2247class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr): 2248 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2249 """Architecture source file""" 2250 2251 @model_serializer(mode="wrap", when_used="unless-none") 2252 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 2253 return package_file_descr_serializer(self, nxt, info)
A file description
Architecture source file
2256class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr): 2257 import_from: str 2258 """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
Where to import the callable from, i.e. from <import_from> import <callable>
2318class WeightsEntryDescrBase(FileDescr): 2319 type: ClassVar[WeightsFormat] 2320 weights_format_name: ClassVar[str] # human readable 2321 2322 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2323 """Source of the weights file.""" 2324 2325 authors: Optional[List[Author]] = None 2326 """Authors 2327 Either the person(s) that have trained this model resulting in the original weights file. 2328 (If this is the initial weights entry, i.e. it does not have a `parent`) 2329 Or the person(s) who have converted the weights to this weights format. 2330 (If this is a child weight, i.e. it has a `parent` field) 2331 """ 2332 2333 parent: Annotated[ 2334 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"]) 2335 ] = None 2336 """The source weights these weights were converted from. 2337 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`, 2338 The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights. 2339 All weight entries except one (the initial set of weights resulting from training the model), 2340 need to have this field.""" 2341 2342 comment: str = "" 2343 """A comment about this weights entry, for example how these weights were created.""" 2344 2345 @model_validator(mode="after") 2346 def _validate(self) -> Self: 2347 if self.type == self.parent: 2348 raise ValueError("Weights entry can't be it's own parent.") 2349 2350 return self 2351 2352 @model_serializer(mode="wrap", when_used="unless-none") 2353 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 2354 return package_file_descr_serializer(self, nxt, info)
A file description
Source of the weights file.
The source weights these weights were converted from.
For example, if a model's weights were converted from the pytorch_state_dict format to torchscript, the pytorch_state_dict weights entry has no parent and is the parent of the torchscript weights.
All weight entries except one (the initial set of weights resulting from training the model) need to have this field.
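For illustration, a hypothetical weights section (shown as a Python dict mirroring the YAML layout of a model RDF; required fields are abbreviated) with the original pytorch_state_dict weights and a converted torchscript child entry:

weights = {
    "pytorch_state_dict": {
        "source": "weights.pt",  # original training result: no `parent`
        # further required fields (architecture, pytorch_version, ...) omitted
    },
    "torchscript": {
        "source": "weights_torchscript.pt",
        "pytorch_version": "2.1",
        "parent": "pytorch_state_dict",  # converted from the original weights
        "comment": "created with torch.jit.trace",
    },
}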
2357class KerasHdf5WeightsDescr(WeightsEntryDescrBase): 2358 type = "keras_hdf5" 2359 weights_format_name: ClassVar[str] = "Keras HDF5" 2360 tensorflow_version: Version 2361 """TensorFlow version used to create these weights."""
A file description
2371class OnnxWeightsDescr(WeightsEntryDescrBase): 2372 type = "onnx" 2373 weights_format_name: ClassVar[str] = "ONNX" 2374 opset_version: Annotated[int, Ge(7)] 2375 """ONNX opset version""" 2376 2377 external_data: Optional[FileDescr_external_data] = None 2378 """Source of the external ONNX data file holding the weights. 2379 (If present **source** holds the ONNX architecture without weights).""" 2380 2381 @model_validator(mode="after") 2382 def _validate_external_data_unique_file_name(self) -> Self: 2383 if self.external_data is not None and ( 2384 extract_file_name(self.source) 2385 == extract_file_name(self.external_data.source) 2386 ): 2387 raise ValueError( 2388 f"ONNX `external_data` file name '{extract_file_name(self.external_data.source)}'" 2389 + " must be different from ONNX `source` file name." 2390 ) 2391 2392 return self
A file description
Source of the external ONNX data file holding the weights. (If present source holds the ONNX architecture without weights).
2395class PytorchStateDictWeightsDescr(WeightsEntryDescrBase): 2396 type = "pytorch_state_dict" 2397 weights_format_name: ClassVar[str] = "Pytorch State Dict" 2398 architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr] 2399 pytorch_version: Version 2400 """Version of the PyTorch library used. 2401 If `architecture.depencencies` is specified it has to include pytorch and any version pinning has to be compatible. 2402 """ 2403 dependencies: Optional[FileDescr_dependencies] = None 2404 """Custom depencies beyond pytorch described in a Conda environment file. 2405 Allows to specify custom dependencies, see conda docs: 2406 - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms) 2407 - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually) 2408 2409 The conda environment file should include pytorch and any version pinning has to be compatible with 2410 **pytorch_version**. 2411 """
A file description
Version of the PyTorch library used.
If architecture.dependencies is specified, it has to include pytorch, and any version pinning has to be compatible.
Custom dependencies beyond pytorch, described in a Conda environment file. This allows specifying custom dependencies; see the conda docs on exporting an environment file across platforms and on creating an environment file manually.
The conda environment file should include pytorch, and any version pinning has to be compatible with pytorch_version.
2414class TensorflowJsWeightsDescr(WeightsEntryDescrBase): 2415 type = "tensorflow_js" 2416 weights_format_name: ClassVar[str] = "Tensorflow.js" 2417 tensorflow_version: Version 2418 """Version of the TensorFlow library used.""" 2419 2420 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2421 """The multi-file weights. 2422 All required files/folders should be a zip archive."""
A file description
The multi-file weights. All required files/folders should be provided as a zip archive.
2425class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase): 2426 type = "tensorflow_saved_model_bundle" 2427 weights_format_name: ClassVar[str] = "Tensorflow Saved Model" 2428 tensorflow_version: Version 2429 """Version of the TensorFlow library used.""" 2430 2431 dependencies: Optional[FileDescr_dependencies] = None 2432 """Custom dependencies beyond tensorflow. 2433 Should include tensorflow and any version pinning has to be compatible with **tensorflow_version**.""" 2434 2435 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2436 """The multi-file weights. 2437 All required files/folders should be a zip archive."""
A file description
Custom dependencies beyond tensorflow. Should include tensorflow, and any version pinning has to be compatible with tensorflow_version.
The multi-file weights. All required files/folders should be provided as a zip archive.
2440class TorchscriptWeightsDescr(WeightsEntryDescrBase): 2441 type = "torchscript" 2442 weights_format_name: ClassVar[str] = "TorchScript" 2443 pytorch_version: Version 2444 """Version of the PyTorch library used."""
A file description
2447class WeightsDescr(Node): 2448 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None 2449 onnx: Optional[OnnxWeightsDescr] = None 2450 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None 2451 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None 2452 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = ( 2453 None 2454 ) 2455 torchscript: Optional[TorchscriptWeightsDescr] = None 2456 2457 @model_validator(mode="after") 2458 def check_entries(self) -> Self: 2459 entries = {wtype for wtype, entry in self if entry is not None} 2460 2461 if not entries: 2462 raise ValueError("Missing weights entry") 2463 2464 entries_wo_parent = { 2465 wtype 2466 for wtype, entry in self 2467 if entry is not None and hasattr(entry, "parent") and entry.parent is None 2468 } 2469 if len(entries_wo_parent) != 1: 2470 issue_warning( 2471 "Exactly one weights entry may not specify the `parent` field (got" 2472 + " {value}). That entry is considered the original set of model weights." 2473 + " Other weight formats are created through conversion of the orignal or" 2474 + " already converted weights. They have to reference the weights format" 2475 + " they were converted from as their `parent`.", 2476 value=len(entries_wo_parent), 2477 field="weights", 2478 ) 2479 2480 for wtype, entry in self: 2481 if entry is None: 2482 continue 2483 2484 assert hasattr(entry, "type") 2485 assert hasattr(entry, "parent") 2486 assert wtype == entry.type 2487 if ( 2488 entry.parent is not None and entry.parent not in entries 2489 ): # self reference checked for `parent` field 2490 raise ValueError( 2491 f"`weights.{wtype}.parent={entry.parent} not in specified weight" 2492 + f" formats: {entries}" 2493 ) 2494 2495 return self 2496 2497 def __getitem__( 2498 self, 2499 key: Literal[ 2500 "keras_hdf5", 2501 "onnx", 2502 "pytorch_state_dict", 2503 "tensorflow_js", 2504 "tensorflow_saved_model_bundle", 2505 "torchscript", 2506 ], 2507 ): 2508 if key == "keras_hdf5": 2509 ret = self.keras_hdf5 2510 elif key == "onnx": 2511 ret = self.onnx 2512 elif key == "pytorch_state_dict": 2513 ret = self.pytorch_state_dict 2514 elif key == "tensorflow_js": 2515 ret = self.tensorflow_js 2516 elif key == "tensorflow_saved_model_bundle": 2517 ret = self.tensorflow_saved_model_bundle 2518 elif key == "torchscript": 2519 ret = self.torchscript 2520 else: 2521 raise KeyError(key) 2522 2523 if ret is None: 2524 raise KeyError(key) 2525 2526 return ret 2527 2528 @property 2529 def available_formats(self): 2530 return { 2531 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 2532 **({} if self.onnx is None else {"onnx": self.onnx}), 2533 **( 2534 {} 2535 if self.pytorch_state_dict is None 2536 else {"pytorch_state_dict": self.pytorch_state_dict} 2537 ), 2538 **( 2539 {} 2540 if self.tensorflow_js is None 2541 else {"tensorflow_js": self.tensorflow_js} 2542 ), 2543 **( 2544 {} 2545 if self.tensorflow_saved_model_bundle is None 2546 else { 2547 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 2548 } 2549 ), 2550 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 2551 } 2552 2553 @property 2554 def missing_formats(self): 2555 return { 2556 wf for wf in get_args(WeightsFormat) if wf not in self.available_formats 2557 }
2457 @model_validator(mode="after") 2458 def check_entries(self) -> Self: 2459 entries = {wtype for wtype, entry in self if entry is not None} 2460 2461 if not entries: 2462 raise ValueError("Missing weights entry") 2463 2464 entries_wo_parent = { 2465 wtype 2466 for wtype, entry in self 2467 if entry is not None and hasattr(entry, "parent") and entry.parent is None 2468 } 2469 if len(entries_wo_parent) != 1: 2470 issue_warning( 2471 "Exactly one weights entry may not specify the `parent` field (got" 2472 + " {value}). That entry is considered the original set of model weights." 2473 + " Other weight formats are created through conversion of the orignal or" 2474 + " already converted weights. They have to reference the weights format" 2475 + " they were converted from as their `parent`.", 2476 value=len(entries_wo_parent), 2477 field="weights", 2478 ) 2479 2480 for wtype, entry in self: 2481 if entry is None: 2482 continue 2483 2484 assert hasattr(entry, "type") 2485 assert hasattr(entry, "parent") 2486 assert wtype == entry.type 2487 if ( 2488 entry.parent is not None and entry.parent not in entries 2489 ): # self reference checked for `parent` field 2490 raise ValueError( 2491 f"`weights.{wtype}.parent={entry.parent} not in specified weight" 2492 + f" formats: {entries}" 2493 ) 2494 2495 return self
2528 @property 2529 def available_formats(self): 2530 return { 2531 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 2532 **({} if self.onnx is None else {"onnx": self.onnx}), 2533 **( 2534 {} 2535 if self.pytorch_state_dict is None 2536 else {"pytorch_state_dict": self.pytorch_state_dict} 2537 ), 2538 **( 2539 {} 2540 if self.tensorflow_js is None 2541 else {"tensorflow_js": self.tensorflow_js} 2542 ), 2543 **( 2544 {} 2545 if self.tensorflow_saved_model_bundle is None 2546 else { 2547 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 2548 } 2549 ), 2550 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 2551 }
2564class LinkedModel(LinkedResourceBase): 2565 """Reference to a bioimage.io model.""" 2566 2567 id: ModelId 2568 """A valid model `id` from the bioimage.io collection."""
Reference to a bioimage.io model.
2590class ReproducibilityTolerance(Node, extra="allow"): 2591 """Describes what small numerical differences -- if any -- may be tolerated 2592 in the generated output when executing in different environments. 2593 2594 A tensor element *output* is considered mismatched to the **test_tensor** if 2595 abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**). 2596 (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).) 2597 2598 Motivation: 2599 For testing we can request the respective deep learning frameworks to be as 2600 reproducible as possible by setting seeds and chosing deterministic algorithms, 2601 but differences in operating systems, available hardware and installed drivers 2602 may still lead to numerical differences. 2603 """ 2604 2605 relative_tolerance: RelativeTolerance = 1e-3 2606 """Maximum relative tolerance of reproduced test tensor.""" 2607 2608 absolute_tolerance: AbsoluteTolerance = 1e-4 2609 """Maximum absolute tolerance of reproduced test tensor.""" 2610 2611 mismatched_elements_per_million: MismatchedElementsPerMillion = 100 2612 """Maximum number of mismatched elements/pixels per million to tolerate.""" 2613 2614 output_ids: Sequence[TensorId] = () 2615 """Limits the output tensor IDs these reproducibility details apply to.""" 2616 2617 weights_formats: Sequence[WeightsFormat] = () 2618 """Limits the weights formats these details apply to."""
Describes what small numerical differences -- if any -- may be tolerated in the generated output when executing in different environments.
A tensor element output is considered mismatched to the test_tensor if abs(output - test_tensor) > absolute_tolerance + relative_tolerance * abs(test_tensor). (Internally we call numpy.testing.assert_allclose.)
Motivation:
For testing we can request the respective deep learning frameworks to be as reproducible as possible by setting seeds and choosing deterministic algorithms, but differences in operating systems, available hardware and installed drivers may still lead to numerical differences.
Maximum relative tolerance of reproduced test tensor.
Maximum absolute tolerance of reproduced test tensor.
Maximum number of mismatched elements/pixels per million to tolerate.
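The mismatch criterion given above can be checked directly with NumPy. A minimal sketch (not the library's internal test code) using the default tolerances:

import numpy as np

def mismatched_per_million(output, test_tensor, atol=1e-4, rtol=1e-3):
    # an element is mismatched if abs(output - test_tensor) > atol + rtol * abs(test_tensor)
    mismatch = np.abs(output - test_tensor) > atol + rtol * np.abs(test_tensor)
    return 1e6 * mismatch.sum() / mismatch.size

test_tensor = np.random.rand(64, 64).astype("float32")
output = test_tensor + np.float32(1e-5)  # tiny reproducibility difference
assert mismatched_per_million(output, test_tensor) <= 100  # default per-million budget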
2621class BioimageioConfig(Node, extra="allow"): 2622 reproducibility_tolerance: Sequence[ReproducibilityTolerance] = () 2623 """Tolerances to allow when reproducing the model's test outputs 2624 from the model's test inputs. 2625 Only the first entry matching tensor id and weights format is considered. 2626 """
2629class Config(Node, extra="allow"): 2630 bioimageio: BioimageioConfig = Field( 2631 default_factory=BioimageioConfig.model_construct 2632 )
2635class ModelDescr(GenericModelDescrBase): 2636 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. 2637 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 2638 """ 2639 2640 implemented_format_version: ClassVar[Literal["0.5.6"]] = "0.5.6" 2641 if TYPE_CHECKING: 2642 format_version: Literal["0.5.6"] = "0.5.6" 2643 else: 2644 format_version: Literal["0.5.6"] 2645 """Version of the bioimage.io model description specification used. 2646 When creating a new model always use the latest micro/patch version described here. 2647 The `format_version` is important for any consumer software to understand how to parse the fields. 2648 """ 2649 2650 implemented_type: ClassVar[Literal["model"]] = "model" 2651 if TYPE_CHECKING: 2652 type: Literal["model"] = "model" 2653 else: 2654 type: Literal["model"] 2655 """Specialized resource type 'model'""" 2656 2657 id: Optional[ModelId] = None 2658 """bioimage.io-wide unique resource identifier 2659 assigned by bioimage.io; version **un**specific.""" 2660 2661 authors: FAIR[List[Author]] = Field( 2662 default_factory=cast(Callable[[], List[Author]], list) 2663 ) 2664 """The authors are the creators of the model RDF and the primary points of contact.""" 2665 2666 documentation: FAIR[Optional[FileSource_documentation]] = None 2667 """URL or relative path to a markdown file with additional documentation. 2668 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. 2669 The documentation should include a '#[#] Validation' (sub)section 2670 with details on how to quantitatively validate the model on unseen data.""" 2671 2672 @field_validator("documentation", mode="after") 2673 @classmethod 2674 def _validate_documentation( 2675 cls, value: Optional[FileSource_documentation] 2676 ) -> Optional[FileSource_documentation]: 2677 if not get_validation_context().perform_io_checks or value is None: 2678 return value 2679 2680 doc_reader = get_reader(value) 2681 doc_content = doc_reader.read().decode(encoding="utf-8") 2682 if not re.search("#.*[vV]alidation", doc_content): 2683 issue_warning( 2684 "No '# Validation' (sub)section found in {value}.", 2685 value=value, 2686 field="documentation", 2687 ) 2688 2689 return value 2690 2691 inputs: NotEmpty[Sequence[InputTensorDescr]] 2692 """Describes the input tensors expected by this model.""" 2693 2694 @field_validator("inputs", mode="after") 2695 @classmethod 2696 def _validate_input_axes( 2697 cls, inputs: Sequence[InputTensorDescr] 2698 ) -> Sequence[InputTensorDescr]: 2699 input_size_refs = cls._get_axes_with_independent_size(inputs) 2700 2701 for i, ipt in enumerate(inputs): 2702 valid_independent_refs: Dict[ 2703 Tuple[TensorId, AxisId], 2704 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2705 ] = { 2706 **{ 2707 (ipt.id, a.id): (ipt, a, a.size) 2708 for a in ipt.axes 2709 if not isinstance(a, BatchAxis) 2710 and isinstance(a.size, (int, ParameterizedSize)) 2711 }, 2712 **input_size_refs, 2713 } 2714 for a, ax in enumerate(ipt.axes): 2715 cls._validate_axis( 2716 "inputs", 2717 i=i, 2718 tensor_id=ipt.id, 2719 a=a, 2720 axis=ax, 2721 valid_independent_refs=valid_independent_refs, 2722 ) 2723 return inputs 2724 2725 @staticmethod 2726 def _validate_axis( 2727 field_name: str, 2728 i: int, 2729 tensor_id: TensorId, 2730 a: int, 2731 axis: AnyAxis, 2732 valid_independent_refs: Dict[ 2733 Tuple[TensorId, AxisId], 2734 Tuple[TensorDescr, AnyAxis, Union[int, 
ParameterizedSize]], 2735 ], 2736 ): 2737 if isinstance(axis, BatchAxis) or isinstance( 2738 axis.size, (int, ParameterizedSize, DataDependentSize) 2739 ): 2740 return 2741 elif not isinstance(axis.size, SizeReference): 2742 assert_never(axis.size) 2743 2744 # validate axis.size SizeReference 2745 ref = (axis.size.tensor_id, axis.size.axis_id) 2746 if ref not in valid_independent_refs: 2747 raise ValueError( 2748 "Invalid tensor axis reference at" 2749 + f" {field_name}[{i}].axes[{a}].size: {axis.size}." 2750 ) 2751 if ref == (tensor_id, axis.id): 2752 raise ValueError( 2753 "Self-referencing not allowed for" 2754 + f" {field_name}[{i}].axes[{a}].size: {axis.size}" 2755 ) 2756 if axis.type == "channel": 2757 if valid_independent_refs[ref][1].type != "channel": 2758 raise ValueError( 2759 "A channel axis' size may only reference another fixed size" 2760 + " channel axis." 2761 ) 2762 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names: 2763 ref_size = valid_independent_refs[ref][2] 2764 assert isinstance(ref_size, int), ( 2765 "channel axis ref (another channel axis) has to specify fixed" 2766 + " size" 2767 ) 2768 generated_channel_names = [ 2769 Identifier(axis.channel_names.format(i=i)) 2770 for i in range(1, ref_size + 1) 2771 ] 2772 axis.channel_names = generated_channel_names 2773 2774 if (ax_unit := getattr(axis, "unit", None)) != ( 2775 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None) 2776 ): 2777 raise ValueError( 2778 "The units of an axis and its reference axis need to match, but" 2779 + f" '{ax_unit}' != '{ref_unit}'." 2780 ) 2781 ref_axis = valid_independent_refs[ref][1] 2782 if isinstance(ref_axis, BatchAxis): 2783 raise ValueError( 2784 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}" 2785 + " (a batch axis is not allowed as reference)." 2786 ) 2787 2788 if isinstance(axis, WithHalo): 2789 min_size = axis.size.get_size(axis, ref_axis, n=0) 2790 if (min_size - 2 * axis.halo) < 1: 2791 raise ValueError( 2792 f"axis {axis.id} with minimum size {min_size} is too small for halo" 2793 + f" {axis.halo}." 2794 ) 2795 2796 input_halo = axis.halo * axis.scale / ref_axis.scale 2797 if input_halo != int(input_halo) or input_halo % 2 == 1: 2798 raise ValueError( 2799 f"input_halo {input_halo} (output_halo {axis.halo} *" 2800 + f" output_scale {axis.scale} / input_scale {ref_axis.scale})" 2801 + f" {tensor_id}.{axis.id}." 
2802 ) 2803 2804 @model_validator(mode="after") 2805 def _validate_test_tensors(self) -> Self: 2806 if not get_validation_context().perform_io_checks: 2807 return self 2808 2809 test_output_arrays = [ 2810 None if descr.test_tensor is None else load_array(descr.test_tensor) 2811 for descr in self.outputs 2812 ] 2813 test_input_arrays = [ 2814 None if descr.test_tensor is None else load_array(descr.test_tensor) 2815 for descr in self.inputs 2816 ] 2817 2818 tensors = { 2819 descr.id: (descr, array) 2820 for descr, array in zip( 2821 chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays 2822 ) 2823 } 2824 validate_tensors(tensors, tensor_origin="test_tensor") 2825 2826 output_arrays = { 2827 descr.id: array for descr, array in zip(self.outputs, test_output_arrays) 2828 } 2829 for rep_tol in self.config.bioimageio.reproducibility_tolerance: 2830 if not rep_tol.absolute_tolerance: 2831 continue 2832 2833 if rep_tol.output_ids: 2834 out_arrays = { 2835 oid: a 2836 for oid, a in output_arrays.items() 2837 if oid in rep_tol.output_ids 2838 } 2839 else: 2840 out_arrays = output_arrays 2841 2842 for out_id, array in out_arrays.items(): 2843 if array is None: 2844 continue 2845 2846 if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01: 2847 raise ValueError( 2848 "config.bioimageio.reproducibility_tolerance.absolute_tolerance=" 2849 + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}" 2850 + f" (1% of the maximum value of the test tensor '{out_id}')" 2851 ) 2852 2853 return self 2854 2855 @model_validator(mode="after") 2856 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self: 2857 ipt_refs = {t.id for t in self.inputs} 2858 out_refs = {t.id for t in self.outputs} 2859 for ipt in self.inputs: 2860 for p in ipt.preprocessing: 2861 ref = p.kwargs.get("reference_tensor") 2862 if ref is None: 2863 continue 2864 if ref not in ipt_refs: 2865 raise ValueError( 2866 f"`reference_tensor` '{ref}' not found. Valid input tensor" 2867 + f" references are: {ipt_refs}." 2868 ) 2869 2870 for out in self.outputs: 2871 for p in out.postprocessing: 2872 ref = p.kwargs.get("reference_tensor") 2873 if ref is None: 2874 continue 2875 2876 if ref not in ipt_refs and ref not in out_refs: 2877 raise ValueError( 2878 f"`reference_tensor` '{ref}' not found. Valid tensor references" 2879 + f" are: {ipt_refs | out_refs}." 2880 ) 2881 2882 return self 2883 2884 # TODO: use validate funcs in validate_test_tensors 2885 # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]: 2886 2887 name: Annotated[ 2888 str, 2889 RestrictCharacters(string.ascii_letters + string.digits + "_+- ()"), 2890 MinLen(5), 2891 MaxLen(128), 2892 warn(MaxLen(64), "Name longer than 64 characters.", INFO), 2893 ] 2894 """A human-readable name of this model. 2895 It should be no longer than 64 characters 2896 and may only contain letter, number, underscore, minus, parentheses and spaces. 2897 We recommend to chose a name that refers to the model's task and image modality. 
2898 """ 2899 2900 outputs: NotEmpty[Sequence[OutputTensorDescr]] 2901 """Describes the output tensors.""" 2902 2903 @field_validator("outputs", mode="after") 2904 @classmethod 2905 def _validate_tensor_ids( 2906 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo 2907 ) -> Sequence[OutputTensorDescr]: 2908 tensor_ids = [ 2909 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", []) 2910 ] 2911 duplicate_tensor_ids: List[str] = [] 2912 seen: Set[str] = set() 2913 for t in tensor_ids: 2914 if t in seen: 2915 duplicate_tensor_ids.append(t) 2916 2917 seen.add(t) 2918 2919 if duplicate_tensor_ids: 2920 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}") 2921 2922 return outputs 2923 2924 @staticmethod 2925 def _get_axes_with_parameterized_size( 2926 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2927 ): 2928 return { 2929 f"{t.id}.{a.id}": (t, a, a.size) 2930 for t in io 2931 for a in t.axes 2932 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize) 2933 } 2934 2935 @staticmethod 2936 def _get_axes_with_independent_size( 2937 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2938 ): 2939 return { 2940 (t.id, a.id): (t, a, a.size) 2941 for t in io 2942 for a in t.axes 2943 if not isinstance(a, BatchAxis) 2944 and isinstance(a.size, (int, ParameterizedSize)) 2945 } 2946 2947 @field_validator("outputs", mode="after") 2948 @classmethod 2949 def _validate_output_axes( 2950 cls, outputs: List[OutputTensorDescr], info: ValidationInfo 2951 ) -> List[OutputTensorDescr]: 2952 input_size_refs = cls._get_axes_with_independent_size( 2953 info.data.get("inputs", []) 2954 ) 2955 output_size_refs = cls._get_axes_with_independent_size(outputs) 2956 2957 for i, out in enumerate(outputs): 2958 valid_independent_refs: Dict[ 2959 Tuple[TensorId, AxisId], 2960 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2961 ] = { 2962 **{ 2963 (out.id, a.id): (out, a, a.size) 2964 for a in out.axes 2965 if not isinstance(a, BatchAxis) 2966 and isinstance(a.size, (int, ParameterizedSize)) 2967 }, 2968 **input_size_refs, 2969 **output_size_refs, 2970 } 2971 for a, ax in enumerate(out.axes): 2972 cls._validate_axis( 2973 "outputs", 2974 i, 2975 out.id, 2976 a, 2977 ax, 2978 valid_independent_refs=valid_independent_refs, 2979 ) 2980 2981 return outputs 2982 2983 packaged_by: List[Author] = Field( 2984 default_factory=cast(Callable[[], List[Author]], list) 2985 ) 2986 """The persons that have packaged and uploaded this model. 2987 Only required if those persons differ from the `authors`.""" 2988 2989 parent: Optional[LinkedModel] = None 2990 """The model from which this model is derived, e.g. by fine-tuning the weights.""" 2991 2992 @model_validator(mode="after") 2993 def _validate_parent_is_not_self(self) -> Self: 2994 if self.parent is not None and self.parent.id == self.id: 2995 raise ValueError("A model description may not reference itself as parent.") 2996 2997 return self 2998 2999 run_mode: Annotated[ 3000 Optional[RunMode], 3001 warn(None, "Run mode '{value}' has limited support across consumer softwares."), 3002 ] = None 3003 """Custom run mode for this model: for more complex prediction procedures like test time 3004 data augmentation that currently cannot be expressed in the specification. 
3005 No standard run modes are defined yet.""" 3006 3007 timestamp: Datetime = Field(default_factory=Datetime.now) 3008 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format 3009 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat). 3010 (In Python a datetime object is valid, too).""" 3011 3012 training_data: Annotated[ 3013 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02], 3014 Field(union_mode="left_to_right"), 3015 ] = None 3016 """The dataset used to train this model""" 3017 3018 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)] 3019 """The weights for this model. 3020 Weights can be given for different formats, but should otherwise be equivalent. 3021 The available weight formats determine which consumers can use this model.""" 3022 3023 config: Config = Field(default_factory=Config.model_construct) 3024 3025 @model_validator(mode="after") 3026 def _add_default_cover(self) -> Self: 3027 if not get_validation_context().perform_io_checks or self.covers: 3028 return self 3029 3030 try: 3031 generated_covers = generate_covers( 3032 [ 3033 (t, load_array(t.test_tensor)) 3034 for t in self.inputs 3035 if t.test_tensor is not None 3036 ], 3037 [ 3038 (t, load_array(t.test_tensor)) 3039 for t in self.outputs 3040 if t.test_tensor is not None 3041 ], 3042 ) 3043 except Exception as e: 3044 issue_warning( 3045 "Failed to generate cover image(s): {e}", 3046 value=self.covers, 3047 msg_context=dict(e=e), 3048 field="covers", 3049 ) 3050 else: 3051 self.covers.extend(generated_covers) 3052 3053 return self 3054 3055 def get_input_test_arrays(self) -> List[NDArray[Any]]: 3056 return self._get_test_arrays(self.inputs) 3057 3058 def get_output_test_arrays(self) -> List[NDArray[Any]]: 3059 return self._get_test_arrays(self.outputs) 3060 3061 @staticmethod 3062 def _get_test_arrays( 3063 io_descr: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 3064 ): 3065 ts: List[FileDescr] = [] 3066 for d in io_descr: 3067 if d.test_tensor is None: 3068 raise ValueError( 3069 f"Failed to get test arrays: description of '{d.id}' is missing a `test_tensor`." 3070 ) 3071 ts.append(d.test_tensor) 3072 3073 data = [load_array(t) for t in ts] 3074 assert all(isinstance(d, np.ndarray) for d in data) 3075 return data 3076 3077 @staticmethod 3078 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 3079 batch_size = 1 3080 tensor_with_batchsize: Optional[TensorId] = None 3081 for tid in tensor_sizes: 3082 for aid, s in tensor_sizes[tid].items(): 3083 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 3084 continue 3085 3086 if batch_size != 1: 3087 assert tensor_with_batchsize is not None 3088 raise ValueError( 3089 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 3090 ) 3091 3092 batch_size = s 3093 tensor_with_batchsize = tid 3094 3095 return batch_size 3096 3097 def get_output_tensor_sizes( 3098 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 3099 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: 3100 """Returns the tensor output sizes for given **input_sizes**. 3101 Only if **input_sizes** has a valid input shape, the tensor output size is exact. 
3102 Otherwise it might be larger than the actual (valid) output""" 3103 batch_size = self.get_batch_size(input_sizes) 3104 ns = self.get_ns(input_sizes) 3105 3106 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 3107 return tensor_sizes.outputs 3108 3109 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 3110 """get parameter `n` for each parameterized axis 3111 such that the valid input size is >= the given input size""" 3112 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 3113 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 3114 for tid in input_sizes: 3115 for aid, s in input_sizes[tid].items(): 3116 size_descr = axes[tid][aid].size 3117 if isinstance(size_descr, ParameterizedSize): 3118 ret[(tid, aid)] = size_descr.get_n(s) 3119 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 3120 pass 3121 else: 3122 assert_never(size_descr) 3123 3124 return ret 3125 3126 def get_tensor_sizes( 3127 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 3128 ) -> _TensorSizes: 3129 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 3130 return _TensorSizes( 3131 { 3132 t: { 3133 aa: axis_sizes.inputs[(tt, aa)] 3134 for tt, aa in axis_sizes.inputs 3135 if tt == t 3136 } 3137 for t in {tt for tt, _ in axis_sizes.inputs} 3138 }, 3139 { 3140 t: { 3141 aa: axis_sizes.outputs[(tt, aa)] 3142 for tt, aa in axis_sizes.outputs 3143 if tt == t 3144 } 3145 for t in {tt for tt, _ in axis_sizes.outputs} 3146 }, 3147 ) 3148 3149 def get_axis_sizes( 3150 self, 3151 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], 3152 batch_size: Optional[int] = None, 3153 *, 3154 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None, 3155 ) -> _AxisSizes: 3156 """Determine input and output block shape for scale factors **ns** 3157 of parameterized input sizes. 3158 3159 Args: 3160 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) 3161 that is parameterized as `size = min + n * step`. 3162 batch_size: The desired size of the batch dimension. 3163 If given **batch_size** overwrites any batch size present in 3164 **max_input_shape**. Default 1. 3165 max_input_shape: Limits the derived block shapes. 3166 Each axis for which the input size, parameterized by `n`, is larger 3167 than **max_input_shape** is set to the minimal value `n_min` for which 3168 this is still true. 3169 Use this for small input samples or large values of **ns**. 3170 Or simply whenever you know the full input shape. 3171 3172 Returns: 3173 Resolved axis sizes for model inputs and outputs. 
3174 """ 3175 max_input_shape = max_input_shape or {} 3176 if batch_size is None: 3177 for (_t_id, a_id), s in max_input_shape.items(): 3178 if a_id == BATCH_AXIS_ID: 3179 batch_size = s 3180 break 3181 else: 3182 batch_size = 1 3183 3184 all_axes = { 3185 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs) 3186 } 3187 3188 inputs: Dict[Tuple[TensorId, AxisId], int] = {} 3189 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {} 3190 3191 def get_axis_size(a: Union[InputAxis, OutputAxis]): 3192 if isinstance(a, BatchAxis): 3193 if (t_descr.id, a.id) in ns: 3194 logger.warning( 3195 "Ignoring unexpected size increment factor (n) for batch axis" 3196 + " of tensor '{}'.", 3197 t_descr.id, 3198 ) 3199 return batch_size 3200 elif isinstance(a.size, int): 3201 if (t_descr.id, a.id) in ns: 3202 logger.warning( 3203 "Ignoring unexpected size increment factor (n) for fixed size" 3204 + " axis '{}' of tensor '{}'.", 3205 a.id, 3206 t_descr.id, 3207 ) 3208 return a.size 3209 elif isinstance(a.size, ParameterizedSize): 3210 if (t_descr.id, a.id) not in ns: 3211 raise ValueError( 3212 "Size increment factor (n) missing for parametrized axis" 3213 + f" '{a.id}' of tensor '{t_descr.id}'." 3214 ) 3215 n = ns[(t_descr.id, a.id)] 3216 s_max = max_input_shape.get((t_descr.id, a.id)) 3217 if s_max is not None: 3218 n = min(n, a.size.get_n(s_max)) 3219 3220 return a.size.get_size(n) 3221 3222 elif isinstance(a.size, SizeReference): 3223 if (t_descr.id, a.id) in ns: 3224 logger.warning( 3225 "Ignoring unexpected size increment factor (n) for axis '{}'" 3226 + " of tensor '{}' with size reference.", 3227 a.id, 3228 t_descr.id, 3229 ) 3230 assert not isinstance(a, BatchAxis) 3231 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id] 3232 assert not isinstance(ref_axis, BatchAxis) 3233 ref_key = (a.size.tensor_id, a.size.axis_id) 3234 ref_size = inputs.get(ref_key, outputs.get(ref_key)) 3235 assert ref_size is not None, ref_key 3236 assert not isinstance(ref_size, _DataDepSize), ref_key 3237 return a.size.get_size( 3238 axis=a, 3239 ref_axis=ref_axis, 3240 ref_size=ref_size, 3241 ) 3242 elif isinstance(a.size, DataDependentSize): 3243 if (t_descr.id, a.id) in ns: 3244 logger.warning( 3245 "Ignoring unexpected increment factor (n) for data dependent" 3246 + " size axis '{}' of tensor '{}'.", 3247 a.id, 3248 t_descr.id, 3249 ) 3250 return _DataDepSize(a.size.min, a.size.max) 3251 else: 3252 assert_never(a.size) 3253 3254 # first resolve all , but the `SizeReference` input sizes 3255 for t_descr in self.inputs: 3256 for a in t_descr.axes: 3257 if not isinstance(a.size, SizeReference): 3258 s = get_axis_size(a) 3259 assert not isinstance(s, _DataDepSize) 3260 inputs[t_descr.id, a.id] = s 3261 3262 # resolve all other input axis sizes 3263 for t_descr in self.inputs: 3264 for a in t_descr.axes: 3265 if isinstance(a.size, SizeReference): 3266 s = get_axis_size(a) 3267 assert not isinstance(s, _DataDepSize) 3268 inputs[t_descr.id, a.id] = s 3269 3270 # resolve all output axis sizes 3271 for t_descr in self.outputs: 3272 for a in t_descr.axes: 3273 assert not isinstance(a.size, ParameterizedSize) 3274 s = get_axis_size(a) 3275 outputs[t_descr.id, a.id] = s 3276 3277 return _AxisSizes(inputs=inputs, outputs=outputs) 3278 3279 @model_validator(mode="before") 3280 @classmethod 3281 def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]: 3282 cls.convert_from_old_format_wo_validation(data) 3283 return data 3284 3285 @classmethod 3286 def convert_from_old_format_wo_validation(cls, 
data: Dict[str, Any]) -> None: 3287 """Convert metadata following an older format version to this classes' format 3288 without validating the result. 3289 """ 3290 if ( 3291 data.get("type") == "model" 3292 and isinstance(fv := data.get("format_version"), str) 3293 and fv.count(".") == 2 3294 ): 3295 fv_parts = fv.split(".") 3296 if any(not p.isdigit() for p in fv_parts): 3297 return 3298 3299 fv_tuple = tuple(map(int, fv_parts)) 3300 3301 assert cls.implemented_format_version_tuple[0:2] == (0, 5) 3302 if fv_tuple[:2] in ((0, 3), (0, 4)): 3303 m04 = _ModelDescr_v0_4.load(data) 3304 if isinstance(m04, InvalidDescr): 3305 try: 3306 updated = _model_conv.convert_as_dict( 3307 m04 # pyright: ignore[reportArgumentType] 3308 ) 3309 except Exception as e: 3310 logger.error( 3311 "Failed to convert from invalid model 0.4 description." 3312 + f"\nerror: {e}" 3313 + "\nProceeding with model 0.5 validation without conversion." 3314 ) 3315 updated = None 3316 else: 3317 updated = _model_conv.convert_as_dict(m04) 3318 3319 if updated is not None: 3320 data.clear() 3321 data.update(updated) 3322 3323 elif fv_tuple[:2] == (0, 5): 3324 # bump patch version 3325 data["format_version"] = cls.implemented_format_version
Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
URL or relative path to a markdown file with additional documentation.
The recommended documentation file name is README.md. An .md suffix is mandatory.
The documentation should include a '#[#] Validation' (sub)section
with details on how to quantitatively validate the model on unseen data.
A human-readable name of this model. It should be no longer than 64 characters and may only contain letters, numbers, underscores, minus signs, parentheses, and spaces. We recommend choosing a name that refers to the model's task and image modality.
The persons that have packaged and uploaded this model.
Only required if those persons differ from the authors.
Custom run mode for this model: for more complex prediction procedures like test time data augmentation that currently cannot be expressed in the specification. No standard run modes are defined yet.
The dataset used to train this model
The weights for this model. Weights can be given for different formats, but should otherwise be equivalent. The available weight formats determine which consumers can use this model.
3077 @staticmethod 3078 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 3079 batch_size = 1 3080 tensor_with_batchsize: Optional[TensorId] = None 3081 for tid in tensor_sizes: 3082 for aid, s in tensor_sizes[tid].items(): 3083 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 3084 continue 3085 3086 if batch_size != 1: 3087 assert tensor_with_batchsize is not None 3088 raise ValueError( 3089 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 3090 ) 3091 3092 batch_size = s 3093 tensor_with_batchsize = tid 3094 3095 return batch_size
3097 def get_output_tensor_sizes( 3098 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 3099 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: 3100 """Returns the tensor output sizes for given **input_sizes**. 3101 Only if **input_sizes** has a valid input shape, the tensor output size is exact. 3102 Otherwise it might be larger than the actual (valid) output""" 3103 batch_size = self.get_batch_size(input_sizes) 3104 ns = self.get_ns(input_sizes) 3105 3106 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 3107 return tensor_sizes.outputs
Returns the tensor output sizes for the given input_sizes. The output sizes are exact only if input_sizes describes a valid input shape; otherwise they may be larger than the actual (valid) output.
3109 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 3110 """get parameter `n` for each parameterized axis 3111 such that the valid input size is >= the given input size""" 3112 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 3113 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 3114 for tid in input_sizes: 3115 for aid, s in input_sizes[tid].items(): 3116 size_descr = axes[tid][aid].size 3117 if isinstance(size_descr, ParameterizedSize): 3118 ret[(tid, aid)] = size_descr.get_n(s) 3119 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 3120 pass 3121 else: 3122 assert_never(size_descr) 3123 3124 return ret
Get the parameter n for each parameterized axis such that the resulting valid input size is >= the given input size.
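A parameterized axis has valid sizes of the form size = min + n * step. A hedged example of how get_n and get_size relate (the expected return values assume the smallest qualifying n is chosen):

ps = ParameterizedSize(min=32, step=16)
n = ps.get_n(100)      # expected: 5, since 32 + 5 * 16 = 112 >= 100, while 32 + 4 * 16 = 96 < 100
size = ps.get_size(n)  # expected: 112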
3126 def get_tensor_sizes( 3127 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 3128 ) -> _TensorSizes: 3129 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 3130 return _TensorSizes( 3131 { 3132 t: { 3133 aa: axis_sizes.inputs[(tt, aa)] 3134 for tt, aa in axis_sizes.inputs 3135 if tt == t 3136 } 3137 for t in {tt for tt, _ in axis_sizes.inputs} 3138 }, 3139 { 3140 t: { 3141 aa: axis_sizes.outputs[(tt, aa)] 3142 for tt, aa in axis_sizes.outputs 3143 if tt == t 3144 } 3145 for t in {tt for tt, _ in axis_sizes.outputs} 3146 }, 3147 )
3149 def get_axis_sizes( 3150 self, 3151 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], 3152 batch_size: Optional[int] = None, 3153 *, 3154 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None, 3155 ) -> _AxisSizes: 3156 """Determine input and output block shape for scale factors **ns** 3157 of parameterized input sizes. 3158 3159 Args: 3160 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) 3161 that is parameterized as `size = min + n * step`. 3162 batch_size: The desired size of the batch dimension. 3163 If given **batch_size** overwrites any batch size present in 3164 **max_input_shape**. Default 1. 3165 max_input_shape: Limits the derived block shapes. 3166 Each axis for which the input size, parameterized by `n`, is larger 3167 than **max_input_shape** is set to the minimal value `n_min` for which 3168 this is still true. 3169 Use this for small input samples or large values of **ns**. 3170 Or simply whenever you know the full input shape. 3171 3172 Returns: 3173 Resolved axis sizes for model inputs and outputs. 3174 """ 3175 max_input_shape = max_input_shape or {} 3176 if batch_size is None: 3177 for (_t_id, a_id), s in max_input_shape.items(): 3178 if a_id == BATCH_AXIS_ID: 3179 batch_size = s 3180 break 3181 else: 3182 batch_size = 1 3183 3184 all_axes = { 3185 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs) 3186 } 3187 3188 inputs: Dict[Tuple[TensorId, AxisId], int] = {} 3189 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {} 3190 3191 def get_axis_size(a: Union[InputAxis, OutputAxis]): 3192 if isinstance(a, BatchAxis): 3193 if (t_descr.id, a.id) in ns: 3194 logger.warning( 3195 "Ignoring unexpected size increment factor (n) for batch axis" 3196 + " of tensor '{}'.", 3197 t_descr.id, 3198 ) 3199 return batch_size 3200 elif isinstance(a.size, int): 3201 if (t_descr.id, a.id) in ns: 3202 logger.warning( 3203 "Ignoring unexpected size increment factor (n) for fixed size" 3204 + " axis '{}' of tensor '{}'.", 3205 a.id, 3206 t_descr.id, 3207 ) 3208 return a.size 3209 elif isinstance(a.size, ParameterizedSize): 3210 if (t_descr.id, a.id) not in ns: 3211 raise ValueError( 3212 "Size increment factor (n) missing for parametrized axis" 3213 + f" '{a.id}' of tensor '{t_descr.id}'." 
3214 ) 3215 n = ns[(t_descr.id, a.id)] 3216 s_max = max_input_shape.get((t_descr.id, a.id)) 3217 if s_max is not None: 3218 n = min(n, a.size.get_n(s_max)) 3219 3220 return a.size.get_size(n) 3221 3222 elif isinstance(a.size, SizeReference): 3223 if (t_descr.id, a.id) in ns: 3224 logger.warning( 3225 "Ignoring unexpected size increment factor (n) for axis '{}'" 3226 + " of tensor '{}' with size reference.", 3227 a.id, 3228 t_descr.id, 3229 ) 3230 assert not isinstance(a, BatchAxis) 3231 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id] 3232 assert not isinstance(ref_axis, BatchAxis) 3233 ref_key = (a.size.tensor_id, a.size.axis_id) 3234 ref_size = inputs.get(ref_key, outputs.get(ref_key)) 3235 assert ref_size is not None, ref_key 3236 assert not isinstance(ref_size, _DataDepSize), ref_key 3237 return a.size.get_size( 3238 axis=a, 3239 ref_axis=ref_axis, 3240 ref_size=ref_size, 3241 ) 3242 elif isinstance(a.size, DataDependentSize): 3243 if (t_descr.id, a.id) in ns: 3244 logger.warning( 3245 "Ignoring unexpected increment factor (n) for data dependent" 3246 + " size axis '{}' of tensor '{}'.", 3247 a.id, 3248 t_descr.id, 3249 ) 3250 return _DataDepSize(a.size.min, a.size.max) 3251 else: 3252 assert_never(a.size) 3253 3254 # first resolve all , but the `SizeReference` input sizes 3255 for t_descr in self.inputs: 3256 for a in t_descr.axes: 3257 if not isinstance(a.size, SizeReference): 3258 s = get_axis_size(a) 3259 assert not isinstance(s, _DataDepSize) 3260 inputs[t_descr.id, a.id] = s 3261 3262 # resolve all other input axis sizes 3263 for t_descr in self.inputs: 3264 for a in t_descr.axes: 3265 if isinstance(a.size, SizeReference): 3266 s = get_axis_size(a) 3267 assert not isinstance(s, _DataDepSize) 3268 inputs[t_descr.id, a.id] = s 3269 3270 # resolve all output axis sizes 3271 for t_descr in self.outputs: 3272 for a in t_descr.axes: 3273 assert not isinstance(a.size, ParameterizedSize) 3274 s = get_axis_size(a) 3275 outputs[t_descr.id, a.id] = s 3276 3277 return _AxisSizes(inputs=inputs, outputs=outputs)
Determine input and output block shape for scale factors ns of parameterized input sizes.
Arguments:
- ns: Scale factor n for each axis (keyed by (tensor_id, axis_id)) that is parameterized as size = min + n * step.
- batch_size: The desired size of the batch dimension. If given, batch_size overwrites any batch size present in max_input_shape. Default 1.
- max_input_shape: Limits the derived block shapes. Each axis for which the input size, parameterized by n, is larger than max_input_shape is set to the minimal value n_min for which this is still true. Use this for small input samples or large values of ns, or simply whenever you know the full input shape.
Returns:
Resolved axis sizes for model inputs and outputs.
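A hedged usage sketch for a hypothetical, already loaded model description model with a single input tensor "input" whose spatial axes x and y are parameterized:

input_sizes = {TensorId("input"): {AxisId("x"): 512, AxisId("y"): 512}}
ns = model.get_ns(input_sizes)                         # n per parameterized (tensor_id, axis_id)
axis_sizes = model.get_axis_sizes(ns, batch_size=1)    # resolved input and output axis sizes
output_sizes = model.get_output_tensor_sizes(input_sizes)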
3285 @classmethod 3286 def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None: 3287 """Convert metadata following an older format version to this classes' format 3288 without validating the result. 3289 """ 3290 if ( 3291 data.get("type") == "model" 3292 and isinstance(fv := data.get("format_version"), str) 3293 and fv.count(".") == 2 3294 ): 3295 fv_parts = fv.split(".") 3296 if any(not p.isdigit() for p in fv_parts): 3297 return 3298 3299 fv_tuple = tuple(map(int, fv_parts)) 3300 3301 assert cls.implemented_format_version_tuple[0:2] == (0, 5) 3302 if fv_tuple[:2] in ((0, 3), (0, 4)): 3303 m04 = _ModelDescr_v0_4.load(data) 3304 if isinstance(m04, InvalidDescr): 3305 try: 3306 updated = _model_conv.convert_as_dict( 3307 m04 # pyright: ignore[reportArgumentType] 3308 ) 3309 except Exception as e: 3310 logger.error( 3311 "Failed to convert from invalid model 0.4 description." 3312 + f"\nerror: {e}" 3313 + "\nProceeding with model 0.5 validation without conversion." 3314 ) 3315 updated = None 3316 else: 3317 updated = _model_conv.convert_as_dict(m04) 3318 3319 if updated is not None: 3320 data.clear() 3321 data.update(updated) 3322 3323 elif fv_tuple[:2] == (0, 5): 3324 # bump patch version 3325 data["format_version"] = cls.implemented_format_version
Convert metadata following an older format version to this class's format without validating the result.
3542def generate_covers( 3543 inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]], 3544 outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]], 3545) -> List[Path]: 3546 def squeeze( 3547 data: NDArray[Any], axes: Sequence[AnyAxis] 3548 ) -> Tuple[NDArray[Any], List[AnyAxis]]: 3549 """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining""" 3550 if data.ndim != len(axes): 3551 raise ValueError( 3552 f"tensor shape {data.shape} does not match described axes" 3553 + f" {[a.id for a in axes]}" 3554 ) 3555 3556 axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1] 3557 return data.squeeze(), axes 3558 3559 def normalize( 3560 data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7 3561 ) -> NDArray[np.float32]: 3562 data = data.astype("float32") 3563 data -= data.min(axis=axis, keepdims=True) 3564 data /= data.max(axis=axis, keepdims=True) + eps 3565 return data 3566 3567 def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]): 3568 original_shape = data.shape 3569 original_axes = list(axes) 3570 data, axes = squeeze(data, axes) 3571 3572 # take slice fom any batch or index axis if needed 3573 # and convert the first channel axis and take a slice from any additional channel axes 3574 slices: Tuple[slice, ...] = () 3575 ndim = data.ndim 3576 ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2 3577 has_c_axis = False 3578 for i, a in enumerate(axes): 3579 s = data.shape[i] 3580 assert s > 1 3581 if ( 3582 isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis)) 3583 and ndim > ndim_need 3584 ): 3585 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3586 ndim -= 1 3587 elif isinstance(a, ChannelAxis): 3588 if has_c_axis: 3589 # second channel axis 3590 data = data[slices + (slice(0, 1),)] 3591 ndim -= 1 3592 else: 3593 has_c_axis = True 3594 if s == 2: 3595 # visualize two channels with cyan and magenta 3596 data = np.concatenate( 3597 [ 3598 data[slices + (slice(1, 2),)], 3599 data[slices + (slice(0, 1),)], 3600 ( 3601 data[slices + (slice(0, 1),)] 3602 + data[slices + (slice(1, 2),)] 3603 ) 3604 / 2, # TODO: take maximum instead? 3605 ], 3606 axis=i, 3607 ) 3608 elif data.shape[i] == 3: 3609 pass # visualize 3 channels as RGB 3610 else: 3611 # visualize first 3 channels as RGB 3612 data = data[slices + (slice(3),)] 3613 3614 assert data.shape[i] == 3 3615 3616 slices += (slice(None),) 3617 3618 data, axes = squeeze(data, axes) 3619 assert len(axes) == ndim 3620 # take slice from z axis if needed 3621 slices = () 3622 if ndim > ndim_need: 3623 for i, a in enumerate(axes): 3624 s = data.shape[i] 3625 if a.id == AxisId("z"): 3626 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3627 data, axes = squeeze(data, axes) 3628 ndim -= 1 3629 break 3630 3631 slices += (slice(None),) 3632 3633 # take slice from any space or time axis 3634 slices = () 3635 3636 for i, a in enumerate(axes): 3637 if ndim <= ndim_need: 3638 break 3639 3640 s = data.shape[i] 3641 assert s > 1 3642 if isinstance( 3643 a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis) 3644 ): 3645 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3646 ndim -= 1 3647 3648 slices += (slice(None),) 3649 3650 del slices 3651 data, axes = squeeze(data, axes) 3652 assert len(axes) == ndim 3653 3654 if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2): 3655 raise ValueError( 3656 f"Failed to construct cover image from shape {original_shape} with axes {[a.id for a in original_axes]}." 
3657 ) 3658 3659 if not has_c_axis: 3660 assert ndim == 2 3661 data = np.repeat(data[:, :, None], 3, axis=2) 3662 axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB")))) 3663 ndim += 1 3664 3665 assert ndim == 3 3666 3667 # transpose axis order such that longest axis comes first... 3668 axis_order: List[int] = list(np.argsort(list(data.shape))) 3669 axis_order.reverse() 3670 # ... and channel axis is last 3671 c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0] 3672 axis_order.append(axis_order.pop(c)) 3673 axes = [axes[ao] for ao in axis_order] 3674 data = data.transpose(axis_order) 3675 3676 # h, w = data.shape[:2] 3677 # if h / w in (1.0 or 2.0): 3678 # pass 3679 # elif h / w < 2: 3680 # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images 3681 3682 norm_along = ( 3683 tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None 3684 ) 3685 # normalize the data and map to 8 bit 3686 data = normalize(data, norm_along) 3687 data = (data * 255).astype("uint8") 3688 3689 return data 3690 3691 def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]): 3692 assert im0.dtype == im1.dtype == np.uint8 3693 assert im0.shape == im1.shape 3694 assert im0.ndim == 3 3695 N, M, C = im0.shape 3696 assert C == 3 3697 out = np.ones((N, M, C), dtype="uint8") 3698 for c in range(C): 3699 outc = np.tril(im0[..., c]) 3700 mask = outc == 0 3701 outc[mask] = np.triu(im1[..., c])[mask] 3702 out[..., c] = outc 3703 3704 return out 3705 3706 if not inputs: 3707 raise ValueError("Missing test input tensor for cover generation.") 3708 3709 if not outputs: 3710 raise ValueError("Missing test output tensor for cover generation.") 3711 3712 ipt_descr, ipt = inputs[0] 3713 out_descr, out = outputs[0] 3714 3715 ipt_img = to_2d_image(ipt, ipt_descr.axes) 3716 out_img = to_2d_image(out, out_descr.axes) 3717 3718 cover_folder = Path(mkdtemp()) 3719 if ipt_img.shape == out_img.shape: 3720 covers = [cover_folder / "cover.png"] 3721 imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img)) 3722 else: 3723 covers = [cover_folder / "input.png", cover_folder / "output.png"] 3724 imwrite(covers[0], ipt_img) 3725 imwrite(covers[1], out_img) 3726 3727 return covers