bioimageio.spec.model.v0_5
1from __future__ import annotations 2 3import collections.abc 4import re 5import string 6import warnings 7from abc import ABC 8from copy import deepcopy 9from itertools import chain 10from math import ceil 11from pathlib import Path, PurePosixPath 12from tempfile import mkdtemp 13from typing import ( 14 TYPE_CHECKING, 15 Any, 16 Callable, 17 ClassVar, 18 Dict, 19 Generic, 20 List, 21 Literal, 22 Mapping, 23 NamedTuple, 24 Optional, 25 Sequence, 26 Set, 27 Tuple, 28 Type, 29 TypeVar, 30 Union, 31 cast, 32) 33 34import numpy as np 35from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate 36from imageio.v3 import imread, imwrite # pyright: ignore[reportUnknownVariableType] 37from loguru import logger 38from numpy.typing import NDArray 39from pydantic import ( 40 AfterValidator, 41 Discriminator, 42 Field, 43 RootModel, 44 SerializationInfo, 45 SerializerFunctionWrapHandler, 46 StrictInt, 47 Tag, 48 ValidationInfo, 49 WrapSerializer, 50 field_validator, 51 model_serializer, 52 model_validator, 53) 54from typing_extensions import Annotated, Self, assert_never, get_args 55 56from .._internal.common_nodes import ( 57 InvalidDescr, 58 Node, 59 NodeWithExplicitlySetFields, 60) 61from .._internal.constants import DTYPE_LIMITS 62from .._internal.field_warning import issue_warning, warn 63from .._internal.io import BioimageioYamlContent as BioimageioYamlContent 64from .._internal.io import FileDescr as FileDescr 65from .._internal.io import ( 66 FileSource, 67 WithSuffix, 68 YamlValue, 69 get_reader, 70 wo_special_file_name, 71) 72from .._internal.io_basics import Sha256 as Sha256 73from .._internal.io_packaging import ( 74 FileDescr_, 75 FileSource_, 76 package_file_descr_serializer, 77) 78from .._internal.io_utils import load_array 79from .._internal.node_converter import Converter 80from .._internal.type_guards import is_dict, is_sequence 81from .._internal.types import ( 82 AbsoluteTolerance, 83 LowerCaseIdentifier, 84 LowerCaseIdentifierAnno, 85 MismatchedElementsPerMillion, 86 RelativeTolerance, 87) 88from .._internal.types import Datetime as Datetime 89from .._internal.types import Identifier as Identifier 90from .._internal.types import NotEmpty as NotEmpty 91from .._internal.types import SiUnit as SiUnit 92from .._internal.url import HttpUrl as HttpUrl 93from .._internal.validation_context import get_validation_context 94from .._internal.validator_annotations import RestrictCharacters 95from .._internal.version_type import Version as Version 96from .._internal.warning_levels import INFO 97from ..dataset.v0_2 import DatasetDescr as DatasetDescr02 98from ..dataset.v0_2 import LinkedDataset as LinkedDataset02 99from ..dataset.v0_3 import DatasetDescr as DatasetDescr 100from ..dataset.v0_3 import DatasetId as DatasetId 101from ..dataset.v0_3 import LinkedDataset as LinkedDataset 102from ..dataset.v0_3 import Uploader as Uploader 103from ..generic.v0_3 import ( 104 VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS, 105) 106from ..generic.v0_3 import Author as Author 107from ..generic.v0_3 import BadgeDescr as BadgeDescr 108from ..generic.v0_3 import CiteEntry as CiteEntry 109from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId 110from ..generic.v0_3 import Doi as Doi 111from ..generic.v0_3 import ( 112 FileSource_documentation, 113 GenericModelDescrBase, 114 LinkedResourceBase, 115 _author_conv, # pyright: ignore[reportPrivateUsage] 116 _maintainer_conv, # pyright: ignore[reportPrivateUsage] 117) 118from ..generic.v0_3 import LicenseId as LicenseId 
119from ..generic.v0_3 import LinkedResource as LinkedResource 120from ..generic.v0_3 import Maintainer as Maintainer 121from ..generic.v0_3 import OrcidId as OrcidId 122from ..generic.v0_3 import RelativeFilePath as RelativeFilePath 123from ..generic.v0_3 import ResourceId as ResourceId 124from .v0_4 import Author as _Author_v0_4 125from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4 126from .v0_4 import CallableFromDepencency as CallableFromDepencency 127from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4 128from .v0_4 import CallableFromFile as _CallableFromFile_v0_4 129from .v0_4 import ClipDescr as _ClipDescr_v0_4 130from .v0_4 import ClipKwargs as ClipKwargs 131from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4 132from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4 133from .v0_4 import KnownRunMode as KnownRunMode 134from .v0_4 import ModelDescr as _ModelDescr_v0_4 135from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4 136from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4 137from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4 138from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4 139from .v0_4 import ProcessingKwargs as ProcessingKwargs 140from .v0_4 import RunMode as RunMode 141from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4 142from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4 143from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4 144from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4 145from .v0_4 import TensorName as _TensorName_v0_4 146from .v0_4 import WeightsFormat as WeightsFormat 147from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4 148from .v0_4 import package_weights 149 150SpaceUnit = Literal[ 151 "attometer", 152 "angstrom", 153 "centimeter", 154 "decimeter", 155 "exameter", 156 "femtometer", 157 "foot", 158 "gigameter", 159 "hectometer", 160 "inch", 161 "kilometer", 162 "megameter", 163 "meter", 164 "micrometer", 165 "mile", 166 "millimeter", 167 "nanometer", 168 "parsec", 169 "petameter", 170 "picometer", 171 "terameter", 172 "yard", 173 "yoctometer", 174 "yottameter", 175 "zeptometer", 176 "zettameter", 177] 178"""Space unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)""" 179 180TimeUnit = Literal[ 181 "attosecond", 182 "centisecond", 183 "day", 184 "decisecond", 185 "exasecond", 186 "femtosecond", 187 "gigasecond", 188 "hectosecond", 189 "hour", 190 "kilosecond", 191 "megasecond", 192 "microsecond", 193 "millisecond", 194 "minute", 195 "nanosecond", 196 "petasecond", 197 "picosecond", 198 "second", 199 "terasecond", 200 "yoctosecond", 201 "yottasecond", 202 "zeptosecond", 203 "zettasecond", 204] 205"""Time unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)""" 206 207AxisType = Literal["batch", "channel", "index", "time", "space"] 208 209_AXIS_TYPE_MAP: Mapping[str, AxisType] = { 210 "b": "batch", 211 "t": "time", 212 "i": "index", 213 "c": "channel", 214 "x": "space", 215 "y": "space", 216 "z": "space", 217} 218 219_AXIS_ID_MAP = { 220 "b": "batch", 221 "t": "time", 222 "i": "index", 223 "c": "channel", 224} 225 226 227class TensorId(LowerCaseIdentifier): 228 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 229 Annotated[LowerCaseIdentifierAnno, MaxLen(32)] 230 ] 231 232 233def _normalize_axis_id(a: str): 234 a = str(a) 235 normalized = _AXIS_ID_MAP.get(a, a) 236 if a != 
normalized: 237 logger.opt(depth=3).warning( 238 "Normalized axis id from '{}' to '{}'.", a, normalized 239 ) 240 return normalized 241 242 243class AxisId(LowerCaseIdentifier): 244 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 245 Annotated[ 246 LowerCaseIdentifierAnno, 247 MaxLen(16), 248 AfterValidator(_normalize_axis_id), 249 ] 250 ] 251 252 253def _is_batch(a: str) -> bool: 254 return str(a) == "batch" 255 256 257def _is_not_batch(a: str) -> bool: 258 return not _is_batch(a) 259 260 261NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)] 262 263PostprocessingId = Literal[ 264 "binarize", 265 "clip", 266 "ensure_dtype", 267 "fixed_zero_mean_unit_variance", 268 "scale_linear", 269 "scale_mean_variance", 270 "scale_range", 271 "sigmoid", 272 "zero_mean_unit_variance", 273] 274PreprocessingId = Literal[ 275 "binarize", 276 "clip", 277 "ensure_dtype", 278 "scale_linear", 279 "sigmoid", 280 "zero_mean_unit_variance", 281 "scale_range", 282] 283 284 285SAME_AS_TYPE = "<same as type>" 286 287 288ParameterizedSize_N = int 289""" 290Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`. 291""" 292 293 294class ParameterizedSize(Node): 295 """Describes a range of valid tensor axis sizes as `size = min + n*step`. 296 297 - **min** and **step** are given by the model description. 298 - All blocksize parameters n = 0,1,2,... yield a valid `size`. 299 - A greater blocksize parameter n = 0,1,2,... results in a greater **size**. 300 This allows adjusting the axis size more generically. 301 """ 302 303 N: ClassVar[Type[int]] = ParameterizedSize_N 304 """Non-negative integer to parameterize this axis""" 305 306 min: Annotated[int, Gt(0)] 307 step: Annotated[int, Gt(0)] 308 309 def validate_size(self, size: int) -> int: 310 if size < self.min: 311 raise ValueError(f"size {size} < {self.min}") 312 if (size - self.min) % self.step != 0: 313 raise ValueError( 314 f"axis of size {size} is not parameterized by `min + n*step` =" 315 + f" `{self.min} + n*{self.step}`" 316 ) 317 318 return size 319 320 def get_size(self, n: ParameterizedSize_N) -> int: 321 return self.min + self.step * n 322 323 def get_n(self, s: int) -> ParameterizedSize_N: 324 """return the smallest n parameterizing a size greater than or equal to `s`""" 325 return ceil((s - self.min) / self.step) 326 327 328class DataDependentSize(Node): 329 min: Annotated[int, Gt(0)] = 1 330 max: Annotated[Optional[int], Gt(1)] = None 331 332 @model_validator(mode="after") 333 def _validate_max_gt_min(self): 334 if self.max is not None and self.min >= self.max: 335 raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}") 336 337 return self 338 339 def validate_size(self, size: int) -> int: 340 if size < self.min: 341 raise ValueError(f"size {size} < {self.min}") 342 343 if self.max is not None and size > self.max: 344 raise ValueError(f"size {size} > {self.max}") 345 346 return size 347 348 349class SizeReference(Node): 350 """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis. 351 352 `axis.size = reference.size * reference.scale / axis.scale + offset` 353 354 Note: 355 1. The axis and the referenced axis need to have the same unit (or no unit). 356 2. Batch axes may not be referenced. 357 3. Fractions are rounded down. 358 4. If the reference axis is `concatenable` the referencing axis is assumed to be 359 `concatenable` as well with the same block order. 360 361 Example: 362 An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196mm².
363 Let's assume that we want to express the image height h in relation to its width w 364 instead of only accepting input images of exactly 100*49 pixels 365 (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`). 366 367 >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2) 368 >>> h = SpaceInputAxis( 369 ... id=AxisId("h"), 370 ... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1), 371 ... unit="millimeter", 372 ... scale=4, 373 ... ) 374 >>> print(h.size.get_size(h, w)) 375 49 376 377 ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49 378 """ 379 380 tensor_id: TensorId 381 """tensor id of the reference axis""" 382 383 axis_id: AxisId 384 """axis id of the reference axis""" 385 386 offset: StrictInt = 0 387 388 def get_size( 389 self, 390 axis: Union[ 391 ChannelAxis, 392 IndexInputAxis, 393 IndexOutputAxis, 394 TimeInputAxis, 395 SpaceInputAxis, 396 TimeOutputAxis, 397 TimeOutputAxisWithHalo, 398 SpaceOutputAxis, 399 SpaceOutputAxisWithHalo, 400 ], 401 ref_axis: Union[ 402 ChannelAxis, 403 IndexInputAxis, 404 IndexOutputAxis, 405 TimeInputAxis, 406 SpaceInputAxis, 407 TimeOutputAxis, 408 TimeOutputAxisWithHalo, 409 SpaceOutputAxis, 410 SpaceOutputAxisWithHalo, 411 ], 412 n: ParameterizedSize_N = 0, 413 ref_size: Optional[int] = None, 414 ): 415 """Compute the concrete size for a given axis and its reference axis. 416 417 Args: 418 axis: The axis this `SizeReference` is the size of. 419 ref_axis: The reference axis to compute the size from. 420 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`) 421 and no fixed **ref_size** is given, 422 **n** is used to compute the size of the parameterized **ref_axis**. 423 ref_size: Overwrite the reference size instead of deriving it from 424 **ref_axis** 425 (**ref_axis.scale** is still used; any given **n** is ignored). 426 """ 427 assert ( 428 axis.size == self 429 ), "Given `axis.size` is not defined by this `SizeReference`" 430 431 assert ( 432 ref_axis.id == self.axis_id 433 ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}." 434 435 assert axis.unit == ref_axis.unit, ( 436 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`," 437 f" but {axis.unit}!={ref_axis.unit}" 438 ) 439 if ref_size is None: 440 if isinstance(ref_axis.size, (int, float)): 441 ref_size = ref_axis.size 442 elif isinstance(ref_axis.size, ParameterizedSize): 443 ref_size = ref_axis.size.get_size(n) 444 elif isinstance(ref_axis.size, DataDependentSize): 445 raise ValueError( 446 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`." 447 ) 448 elif isinstance(ref_axis.size, SizeReference): 449 raise ValueError( 450 "Reference axis referenced in `SizeReference` may not be sized by a" 451 + " `SizeReference` itself." 
452 ) 453 else: 454 assert_never(ref_axis.size) 455 456 return int(ref_size * ref_axis.scale / axis.scale + self.offset) 457 458 @staticmethod 459 def _get_unit( 460 axis: Union[ 461 ChannelAxis, 462 IndexInputAxis, 463 IndexOutputAxis, 464 TimeInputAxis, 465 SpaceInputAxis, 466 TimeOutputAxis, 467 TimeOutputAxisWithHalo, 468 SpaceOutputAxis, 469 SpaceOutputAxisWithHalo, 470 ], 471 ): 472 return axis.unit 473 474 475class AxisBase(NodeWithExplicitlySetFields): 476 id: AxisId 477 """An axis id unique across all axes of one tensor.""" 478 479 description: Annotated[str, MaxLen(128)] = "" 480 481 482class WithHalo(Node): 483 halo: Annotated[int, Ge(1)] 484 """The halo should be cropped from the output tensor to avoid boundary effects. 485 It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`. 486 To document a halo that is already cropped by the model use `size.offset` instead.""" 487 488 size: Annotated[ 489 SizeReference, 490 Field( 491 examples=[ 492 10, 493 SizeReference( 494 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 495 ).model_dump(mode="json"), 496 ] 497 ), 498 ] 499 """reference to another axis with an optional offset (see `SizeReference`)""" 500 501 502BATCH_AXIS_ID = AxisId("batch") 503 504 505class BatchAxis(AxisBase): 506 implemented_type: ClassVar[Literal["batch"]] = "batch" 507 if TYPE_CHECKING: 508 type: Literal["batch"] = "batch" 509 else: 510 type: Literal["batch"] 511 512 id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID 513 size: Optional[Literal[1]] = None 514 """The batch size may be fixed to 1, 515 otherwise (the default) it may be chosen arbitrarily depending on available memory""" 516 517 @property 518 def scale(self): 519 return 1.0 520 521 @property 522 def concatenable(self): 523 return True 524 525 @property 526 def unit(self): 527 return None 528 529 530class ChannelAxis(AxisBase): 531 implemented_type: ClassVar[Literal["channel"]] = "channel" 532 if TYPE_CHECKING: 533 type: Literal["channel"] = "channel" 534 else: 535 type: Literal["channel"] 536 537 id: NonBatchAxisId = AxisId("channel") 538 channel_names: NotEmpty[List[Identifier]] 539 540 @property 541 def size(self) -> int: 542 return len(self.channel_names) 543 544 @property 545 def concatenable(self): 546 return False 547 548 @property 549 def scale(self) -> float: 550 return 1.0 551 552 @property 553 def unit(self): 554 return None 555 556 557class IndexAxisBase(AxisBase): 558 implemented_type: ClassVar[Literal["index"]] = "index" 559 if TYPE_CHECKING: 560 type: Literal["index"] = "index" 561 else: 562 type: Literal["index"] 563 564 id: NonBatchAxisId = AxisId("index") 565 566 @property 567 def scale(self) -> float: 568 return 1.0 569 570 @property 571 def unit(self): 572 return None 573 574 575class _WithInputAxisSize(Node): 576 size: Annotated[ 577 Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference], 578 Field( 579 examples=[ 580 10, 581 ParameterizedSize(min=32, step=16).model_dump(mode="json"), 582 SizeReference( 583 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 584 ).model_dump(mode="json"), 585 ] 586 ), 587 ] 588 """The size/length of this axis can be specified as 589 - fixed integer 590 - parameterized series of valid sizes (`ParameterizedSize`) 591 - reference to another axis with an optional offset (`SizeReference`) 592 """ 593 594 595class IndexInputAxis(IndexAxisBase, _WithInputAxisSize): 596 concatenable: bool = False 597 """If a model has a `concatenable` input axis, it can be processed blockwise, 598 splitting a longer sample 
axis into blocks matching its input tensor description. 599 Output axes are concatenable if they have a `SizeReference` to a concatenable 600 input axis. 601 """ 602 603 604class IndexOutputAxis(IndexAxisBase): 605 size: Annotated[ 606 Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize], 607 Field( 608 examples=[ 609 10, 610 SizeReference( 611 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 612 ).model_dump(mode="json"), 613 ] 614 ), 615 ] 616 """The size/length of this axis can be specified as 617 - fixed integer 618 - reference to another axis with an optional offset (`SizeReference`) 619 - data dependent size using `DataDependentSize` (size is only known after model inference) 620 """ 621 622 623class TimeAxisBase(AxisBase): 624 implemented_type: ClassVar[Literal["time"]] = "time" 625 if TYPE_CHECKING: 626 type: Literal["time"] = "time" 627 else: 628 type: Literal["time"] 629 630 id: NonBatchAxisId = AxisId("time") 631 unit: Optional[TimeUnit] = None 632 scale: Annotated[float, Gt(0)] = 1.0 633 634 635class TimeInputAxis(TimeAxisBase, _WithInputAxisSize): 636 concatenable: bool = False 637 """If a model has a `concatenable` input axis, it can be processed blockwise, 638 splitting a longer sample axis into blocks matching its input tensor description. 639 Output axes are concatenable if they have a `SizeReference` to a concatenable 640 input axis. 641 """ 642 643 644class SpaceAxisBase(AxisBase): 645 implemented_type: ClassVar[Literal["space"]] = "space" 646 if TYPE_CHECKING: 647 type: Literal["space"] = "space" 648 else: 649 type: Literal["space"] 650 651 id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x") 652 unit: Optional[SpaceUnit] = None 653 scale: Annotated[float, Gt(0)] = 1.0 654 655 656class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize): 657 concatenable: bool = False 658 """If a model has a `concatenable` input axis, it can be processed blockwise, 659 splitting a longer sample axis into blocks matching its input tensor description. 660 Output axes are concatenable if they have a `SizeReference` to a concatenable 661 input axis. 
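As a minimal illustrative sketch (the concrete size, unit and scale values below are made up; the classes are the ones defined above):

>>> x = SpaceInputAxis(
...     id=AxisId("x"),
...     size=ParameterizedSize(min=64, step=32),
...     unit="micrometer",
...     scale=0.5,
... )
>>> x.size.get_size(2)  # n=2 gives 64 + 2*32
128
>>> x.size.validate_size(96)  # 96 = 64 + 1*32 is a valid size
96
>>> x.concatenable
False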
662 """ 663 664 665INPUT_AXIS_TYPES = ( 666 BatchAxis, 667 ChannelAxis, 668 IndexInputAxis, 669 TimeInputAxis, 670 SpaceInputAxis, 671) 672"""intended for isinstance comparisons in py<3.10""" 673 674_InputAxisUnion = Union[ 675 BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis 676] 677InputAxis = Annotated[_InputAxisUnion, Discriminator("type")] 678 679 680class _WithOutputAxisSize(Node): 681 size: Annotated[ 682 Union[Annotated[int, Gt(0)], SizeReference], 683 Field( 684 examples=[ 685 10, 686 SizeReference( 687 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 688 ).model_dump(mode="json"), 689 ] 690 ), 691 ] 692 """The size/length of this axis can be specified as 693 - fixed integer 694 - reference to another axis with an optional offset (see `SizeReference`) 695 """ 696 697 698class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize): 699 pass 700 701 702class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo): 703 pass 704 705 706def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]: 707 if isinstance(v, dict): 708 return "with_halo" if "halo" in v else "wo_halo" 709 else: 710 return "with_halo" if hasattr(v, "halo") else "wo_halo" 711 712 713_TimeOutputAxisUnion = Annotated[ 714 Union[ 715 Annotated[TimeOutputAxis, Tag("wo_halo")], 716 Annotated[TimeOutputAxisWithHalo, Tag("with_halo")], 717 ], 718 Discriminator(_get_halo_axis_discriminator_value), 719] 720 721 722class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize): 723 pass 724 725 726class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo): 727 pass 728 729 730_SpaceOutputAxisUnion = Annotated[ 731 Union[ 732 Annotated[SpaceOutputAxis, Tag("wo_halo")], 733 Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")], 734 ], 735 Discriminator(_get_halo_axis_discriminator_value), 736] 737 738 739_OutputAxisUnion = Union[ 740 BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion 741] 742OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")] 743 744OUTPUT_AXIS_TYPES = ( 745 BatchAxis, 746 ChannelAxis, 747 IndexOutputAxis, 748 TimeOutputAxis, 749 TimeOutputAxisWithHalo, 750 SpaceOutputAxis, 751 SpaceOutputAxisWithHalo, 752) 753"""intended for isinstance comparisons in py<3.10""" 754 755 756AnyAxis = Union[InputAxis, OutputAxis] 757 758ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES 759"""intended for isinstance comparisons in py<3.10""" 760 761TVs = Union[ 762 NotEmpty[List[int]], 763 NotEmpty[List[float]], 764 NotEmpty[List[bool]], 765 NotEmpty[List[str]], 766] 767 768 769NominalOrOrdinalDType = Literal[ 770 "float32", 771 "float64", 772 "uint8", 773 "int8", 774 "uint16", 775 "int16", 776 "uint32", 777 "int32", 778 "uint64", 779 "int64", 780 "bool", 781] 782 783 784class NominalOrOrdinalDataDescr(Node): 785 values: TVs 786 """A fixed set of nominal or an ascending sequence of ordinal values. 787 In this case `data.type` is required to be an unsigend integer type, e.g. 'uint8'. 788 String `values` are interpreted as labels for tensor values 0, ..., N. 789 Note: as YAML 1.2 does not natively support a "set" datatype, 790 nominal values should be given as a sequence (aka list/array) as well. 
791 """ 792 793 type: Annotated[ 794 NominalOrOrdinalDType, 795 Field( 796 examples=[ 797 "float32", 798 "uint8", 799 "uint16", 800 "int64", 801 "bool", 802 ], 803 ), 804 ] = "uint8" 805 806 @model_validator(mode="after") 807 def _validate_values_match_type( 808 self, 809 ) -> Self: 810 incompatible: List[Any] = [] 811 for v in self.values: 812 if self.type == "bool": 813 if not isinstance(v, bool): 814 incompatible.append(v) 815 elif self.type in DTYPE_LIMITS: 816 if ( 817 isinstance(v, (int, float)) 818 and ( 819 v < DTYPE_LIMITS[self.type].min 820 or v > DTYPE_LIMITS[self.type].max 821 ) 822 or (isinstance(v, str) and "uint" not in self.type) 823 or (isinstance(v, float) and "int" in self.type) 824 ): 825 incompatible.append(v) 826 else: 827 incompatible.append(v) 828 829 if len(incompatible) == 5: 830 incompatible.append("...") 831 break 832 833 if incompatible: 834 raise ValueError( 835 f"data type '{self.type}' incompatible with values {incompatible}" 836 ) 837 838 return self 839 840 unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None 841 842 @property 843 def range(self): 844 if isinstance(self.values[0], str): 845 return 0, len(self.values) - 1 846 else: 847 return min(self.values), max(self.values) 848 849 850IntervalOrRatioDType = Literal[ 851 "float32", 852 "float64", 853 "uint8", 854 "int8", 855 "uint16", 856 "int16", 857 "uint32", 858 "int32", 859 "uint64", 860 "int64", 861] 862 863 864class IntervalOrRatioDataDescr(Node): 865 type: Annotated[ # todo: rename to dtype 866 IntervalOrRatioDType, 867 Field( 868 examples=["float32", "float64", "uint8", "uint16"], 869 ), 870 ] = "float32" 871 range: Tuple[Optional[float], Optional[float]] = ( 872 None, 873 None, 874 ) 875 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor. 876 `None` corresponds to min/max of what can be expressed by **type**.""" 877 unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit" 878 scale: float = 1.0 879 """Scale for data on an interval (or ratio) scale.""" 880 offset: Optional[float] = None 881 """Offset for data on a ratio scale.""" 882 883 @model_validator(mode="before") 884 def _replace_inf(cls, data: Any): 885 if is_dict(data): 886 if "range" in data and is_sequence(data["range"]): 887 forbidden = ( 888 "inf", 889 "-inf", 890 ".inf", 891 "-.inf", 892 float("inf"), 893 float("-inf"), 894 ) 895 if any(v in forbidden for v in data["range"]): 896 issue_warning("replaced 'inf' value", value=data["range"]) 897 898 data["range"] = tuple( 899 (None if v in forbidden else v) for v in data["range"] 900 ) 901 902 return data 903 904 905TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr] 906 907 908class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC): 909 """processing base class""" 910 911 912class BinarizeKwargs(ProcessingKwargs): 913 """key word arguments for `BinarizeDescr`""" 914 915 threshold: float 916 """The fixed threshold""" 917 918 919class BinarizeAlongAxisKwargs(ProcessingKwargs): 920 """key word arguments for `BinarizeDescr`""" 921 922 threshold: NotEmpty[List[float]] 923 """The fixed threshold values along `axis`""" 924 925 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 926 """The `threshold` axis""" 927 928 929class BinarizeDescr(ProcessingDescrBase): 930 """Binarize the tensor with a fixed threshold. 931 932 Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold` 933 will be set to one, values below the threshold to zero. 
934 935 Examples: 936 - in YAML 937 ```yaml 938 postprocessing: 939 - id: binarize 940 kwargs: 941 axis: 'channel' 942 threshold: [0.25, 0.5, 0.75] 943 ``` 944 - in Python: 945 >>> postprocessing = [BinarizeDescr( 946 ... kwargs=BinarizeAlongAxisKwargs( 947 ... axis=AxisId('channel'), 948 ... threshold=[0.25, 0.5, 0.75], 949 ... ) 950 ... )] 951 """ 952 953 implemented_id: ClassVar[Literal["binarize"]] = "binarize" 954 if TYPE_CHECKING: 955 id: Literal["binarize"] = "binarize" 956 else: 957 id: Literal["binarize"] 958 kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs] 959 960 961class ClipDescr(ProcessingDescrBase): 962 """Set tensor values below min to min and above max to max. 963 964 See `ScaleRangeDescr` for examples. 965 """ 966 967 implemented_id: ClassVar[Literal["clip"]] = "clip" 968 if TYPE_CHECKING: 969 id: Literal["clip"] = "clip" 970 else: 971 id: Literal["clip"] 972 973 kwargs: ClipKwargs 974 975 976class EnsureDtypeKwargs(ProcessingKwargs): 977 """key word arguments for `EnsureDtypeDescr`""" 978 979 dtype: Literal[ 980 "float32", 981 "float64", 982 "uint8", 983 "int8", 984 "uint16", 985 "int16", 986 "uint32", 987 "int32", 988 "uint64", 989 "int64", 990 "bool", 991 ] 992 993 994class EnsureDtypeDescr(ProcessingDescrBase): 995 """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching). 996 997 This can for example be used to ensure the inner neural network model gets a 998 different input tensor data type than the fully described bioimage.io model does. 999 1000 Examples: 1001 The described bioimage.io model (incl. preprocessing) accepts any 1002 float32-compatible tensor, normalizes it with percentiles and clipping and then 1003 casts it to uint8, which is what the neural network in this example expects. 1004 - in YAML 1005 ```yaml 1006 inputs: 1007 - data: 1008 type: float32 # described bioimage.io model is compatible with any float32 input tensor 1009 preprocessing: 1010 - id: scale_range 1011 kwargs: 1012 axes: ['y', 'x'] 1013 max_percentile: 99.8 1014 min_percentile: 5.0 1015 - id: clip 1016 kwargs: 1017 min: 0.0 1018 max: 1.0 1019 - id: ensure_dtype # the neural network of the model requires uint8 1020 kwargs: 1021 dtype: uint8 1022 ``` 1023 - in Python: 1024 >>> preprocessing = [ 1025 ... ScaleRangeDescr( 1026 ... kwargs=ScaleRangeKwargs( 1027 ... axes= (AxisId('y'), AxisId('x')), 1028 ... max_percentile= 99.8, 1029 ... min_percentile= 5.0, 1030 ... ) 1031 ... ), 1032 ... ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)), 1033 ... EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")), 1034 ... ] 1035 """ 1036 1037 implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype" 1038 if TYPE_CHECKING: 1039 id: Literal["ensure_dtype"] = "ensure_dtype" 1040 else: 1041 id: Literal["ensure_dtype"] 1042 1043 kwargs: EnsureDtypeKwargs 1044 1045 1046class ScaleLinearKwargs(ProcessingKwargs): 1047 """Key word arguments for `ScaleLinearDescr`""" 1048 1049 gain: float = 1.0 1050 """multiplicative factor""" 1051 1052 offset: float = 0.0 1053 """additive term""" 1054 1055 @model_validator(mode="after") 1056 def _validate(self) -> Self: 1057 if self.gain == 1.0 and self.offset == 0.0: 1058 raise ValueError( 1059 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 1060 + " != 0.0." 
1061 ) 1062 1063 return self 1064 1065 1066class ScaleLinearAlongAxisKwargs(ProcessingKwargs): 1067 """Key word arguments for `ScaleLinearDescr`""" 1068 1069 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 1070 """The axis of gain and offset values.""" 1071 1072 gain: Union[float, NotEmpty[List[float]]] = 1.0 1073 """multiplicative factor""" 1074 1075 offset: Union[float, NotEmpty[List[float]]] = 0.0 1076 """additive term""" 1077 1078 @model_validator(mode="after") 1079 def _validate(self) -> Self: 1080 1081 if isinstance(self.gain, list): 1082 if isinstance(self.offset, list): 1083 if len(self.gain) != len(self.offset): 1084 raise ValueError( 1085 f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match." 1086 ) 1087 else: 1088 self.offset = [float(self.offset)] * len(self.gain) 1089 elif isinstance(self.offset, list): 1090 self.gain = [float(self.gain)] * len(self.offset) 1091 else: 1092 raise ValueError( 1093 "Do not specify an `axis` for scalar gain and offset values." 1094 ) 1095 1096 if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset): 1097 raise ValueError( 1098 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 1099 + " != 0.0." 1100 ) 1101 1102 return self 1103 1104 1105class ScaleLinearDescr(ProcessingDescrBase): 1106 """Fixed linear scaling. 1107 1108 Examples: 1109 1. Scale with scalar gain and offset 1110 - in YAML 1111 ```yaml 1112 preprocessing: 1113 - id: scale_linear 1114 kwargs: 1115 gain: 2.0 1116 offset: 3.0 1117 ``` 1118 - in Python: 1119 >>> preprocessing = [ 1120 ... ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0)) 1121 ... ] 1122 1123 2. Independent scaling along an axis 1124 - in YAML 1125 ```yaml 1126 preprocessing: 1127 - id: scale_linear 1128 kwargs: 1129 axis: 'channel' 1130 gain: [1.0, 2.0, 3.0] 1131 ``` 1132 - in Python: 1133 >>> preprocessing = [ 1134 ... ScaleLinearDescr( 1135 ... kwargs=ScaleLinearAlongAxisKwargs( 1136 ... axis=AxisId("channel"), 1137 ... gain=[1.0, 2.0, 3.0], 1138 ... ) 1139 ... ) 1140 ... ] 1141 1142 """ 1143 1144 implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear" 1145 if TYPE_CHECKING: 1146 id: Literal["scale_linear"] = "scale_linear" 1147 else: 1148 id: Literal["scale_linear"] 1149 kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs] 1150 1151 1152class SigmoidDescr(ProcessingDescrBase): 1153 """The logistic sigmoid funciton, a.k.a. expit function. 
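It maps every tensor value `v` to `1 / (1 + exp(-v))` and thereby squashes the output into the open interval (0, 1).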
1154 1155 Examples: 1156 - in YAML 1157 ```yaml 1158 postprocessing: 1159 - id: sigmoid 1160 ``` 1161 - in Python: 1162 >>> postprocessing = [SigmoidDescr()] 1163 """ 1164 1165 implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid" 1166 if TYPE_CHECKING: 1167 id: Literal["sigmoid"] = "sigmoid" 1168 else: 1169 id: Literal["sigmoid"] 1170 1171 @property 1172 def kwargs(self) -> ProcessingKwargs: 1173 """empty kwargs""" 1174 return ProcessingKwargs() 1175 1176 1177class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs): 1178 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 1179 1180 mean: float 1181 """The mean value to normalize with.""" 1182 1183 std: Annotated[float, Ge(1e-6)] 1184 """The standard deviation value to normalize with.""" 1185 1186 1187class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs): 1188 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 1189 1190 mean: NotEmpty[List[float]] 1191 """The mean value(s) to normalize with.""" 1192 1193 std: NotEmpty[List[Annotated[float, Ge(1e-6)]]] 1194 """The standard deviation value(s) to normalize with. 1195 Size must match `mean` values.""" 1196 1197 axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])] 1198 """The axis of the mean/std values to normalize each entry along that dimension 1199 separately.""" 1200 1201 @model_validator(mode="after") 1202 def _mean_and_std_match(self) -> Self: 1203 if len(self.mean) != len(self.std): 1204 raise ValueError( 1205 f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})" 1206 + " must match." 1207 ) 1208 1209 return self 1210 1211 1212class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase): 1213 """Subtract a given mean and divide by the standard deviation. 1214 1215 Normalize with fixed, precomputed values for 1216 `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std` 1217 Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given 1218 axes. 1219 1220 Examples: 1221 1. scalar value for whole tensor 1222 - in YAML 1223 ```yaml 1224 preprocessing: 1225 - id: fixed_zero_mean_unit_variance 1226 kwargs: 1227 mean: 103.5 1228 std: 13.7 1229 ``` 1230 - in Python 1231 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 1232 ... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7) 1233 ... )] 1234 1235 2. independently along an axis 1236 - in YAML 1237 ```yaml 1238 preprocessing: 1239 - id: fixed_zero_mean_unit_variance 1240 kwargs: 1241 axis: channel 1242 mean: [101.5, 102.5, 103.5] 1243 std: [11.7, 12.7, 13.7] 1244 ``` 1245 - in Python 1246 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 1247 ... kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 1248 ... axis=AxisId("channel"), 1249 ... mean=[101.5, 102.5, 103.5], 1250 ... std=[11.7, 12.7, 13.7], 1251 ... ) 1252 ... )] 1253 """ 1254 1255 implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = ( 1256 "fixed_zero_mean_unit_variance" 1257 ) 1258 if TYPE_CHECKING: 1259 id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance" 1260 else: 1261 id: Literal["fixed_zero_mean_unit_variance"] 1262 1263 kwargs: Union[ 1264 FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs 1265 ] 1266 1267 1268class ZeroMeanUnitVarianceKwargs(ProcessingKwargs): 1269 """key word arguments for `ZeroMeanUnitVarianceDescr`""" 1270 1271 axes: Annotated[ 1272 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1273 ] = None 1274 """The subset of axes to normalize jointly, i.e. 
axes to reduce to compute mean/std. 1275 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1276 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1277 To normalize each sample independently leave out the 'batch' axis. 1278 Default: Scale all axes jointly.""" 1279 1280 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1281 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`.""" 1282 1283 1284class ZeroMeanUnitVarianceDescr(ProcessingDescrBase): 1285 """Subtract mean and divide by variance. 1286 1287 Examples: 1288 Subtract tensor mean and variance 1289 - in YAML 1290 ```yaml 1291 preprocessing: 1292 - id: zero_mean_unit_variance 1293 ``` 1294 - in Python 1295 >>> preprocessing = [ZeroMeanUnitVarianceDescr()] 1296 """ 1297 1298 implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = ( 1299 "zero_mean_unit_variance" 1300 ) 1301 if TYPE_CHECKING: 1302 id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance" 1303 else: 1304 id: Literal["zero_mean_unit_variance"] 1305 1306 kwargs: ZeroMeanUnitVarianceKwargs = Field( 1307 default_factory=ZeroMeanUnitVarianceKwargs 1308 ) 1309 1310 1311class ScaleRangeKwargs(ProcessingKwargs): 1312 """key word arguments for `ScaleRangeDescr` 1313 1314 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default) 1315 this processing step normalizes data to the [0, 1] intervall. 1316 For other percentiles the normalized values will partially be outside the [0, 1] 1317 intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the 1318 normalized values to a range. 1319 """ 1320 1321 axes: Annotated[ 1322 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1323 ] = None 1324 """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value. 1325 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1326 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1327 To normalize samples independently, leave out the "batch" axis. 1328 Default: Scale all axes jointly.""" 1329 1330 min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0 1331 """The lower percentile used to determine the value to align with zero.""" 1332 1333 max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0 1334 """The upper percentile used to determine the value to align with one. 1335 Has to be bigger than `min_percentile`. 1336 The range is 1 to 100 instead of 0 to 100 to avoid mistakenly 1337 accepting percentiles specified in the range 0.0 to 1.0.""" 1338 1339 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1340 """Epsilon for numeric stability. 1341 `out = (tensor - v_lower) / (v_upper - v_lower + eps)`; 1342 with `v_lower,v_upper` values at the respective percentiles.""" 1343 1344 reference_tensor: Optional[TensorId] = None 1345 """Tensor ID to compute the percentiles from. Default: The tensor itself. 1346 For any tensor in `inputs` only input tensor references are allowed.""" 1347 1348 @field_validator("max_percentile", mode="after") 1349 @classmethod 1350 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float: 1351 if (min_p := info.data["min_percentile"]) >= value: 1352 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}") 1353 1354 return value 1355 1356 1357class ScaleRangeDescr(ProcessingDescrBase): 1358 """Scale with percentiles. 
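The data is mapped as `out = (tensor - v_lower) / (v_upper - v_lower + eps)`, with `v_lower` and `v_upper` being the tensor values at `min_percentile` and `max_percentile`, respectively (see `ScaleRangeKwargs`).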
1359 1360 Examples: 1361 1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0 1362 - in YAML 1363 ```yaml 1364 preprocessing: 1365 - id: scale_range 1366 kwargs: 1367 axes: ['y', 'x'] 1368 max_percentile: 99.8 1369 min_percentile: 5.0 1370 ``` 1371 - in Python 1372 >>> preprocessing = [ 1373 ... ScaleRangeDescr( 1374 ... kwargs=ScaleRangeKwargs( 1375 ... axes= (AxisId('y'), AxisId('x')), 1376 ... max_percentile= 99.8, 1377 ... min_percentile= 5.0, 1378 ... ) 1379 ... ), 1380 ... ClipDescr( 1381 ... kwargs=ClipKwargs( 1382 ... min=0.0, 1383 ... max=1.0, 1384 ... ) 1385 ... ), 1386 ... ] 1387 1388 2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles. 1389 - in YAML 1390 ```yaml 1391 preprocessing: 1392 - id: scale_range 1393 kwargs: 1394 axes: ['y', 'x'] 1395 max_percentile: 99.8 1396 min_percentile: 5.0 1397 - id: scale_range 1398 - id: clip 1399 kwargs: 1400 min: 0.0 1401 max: 1.0 1402 ``` 1403 - in Python 1404 >>> preprocessing = [ScaleRangeDescr( 1405 ... kwargs=ScaleRangeKwargs( 1406 ... axes= (AxisId('y'), AxisId('x')), 1407 ... max_percentile= 99.8, 1408 ... min_percentile= 5.0, 1409 ... ) 1410 ... )] 1411 1412 """ 1413 1414 implemented_id: ClassVar[Literal["scale_range"]] = "scale_range" 1415 if TYPE_CHECKING: 1416 id: Literal["scale_range"] = "scale_range" 1417 else: 1418 id: Literal["scale_range"] 1419 kwargs: ScaleRangeKwargs 1420 1421 1422class ScaleMeanVarianceKwargs(ProcessingKwargs): 1423 """key word arguments for `ScaleMeanVarianceKwargs`""" 1424 1425 reference_tensor: TensorId 1426 """Name of tensor to match.""" 1427 1428 axes: Annotated[ 1429 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1430 ] = None 1431 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 1432 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1433 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1434 To normalize samples independently, leave out the 'batch' axis. 1435 Default: Scale all axes jointly.""" 1436 1437 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1438 """Epsilon for numeric stability: 1439 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`""" 1440 1441 1442class ScaleMeanVarianceDescr(ProcessingDescrBase): 1443 """Scale a tensor's data distribution to match another tensor's mean/std. 
1444 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.` 1445 """ 1446 1447 implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance" 1448 if TYPE_CHECKING: 1449 id: Literal["scale_mean_variance"] = "scale_mean_variance" 1450 else: 1451 id: Literal["scale_mean_variance"] 1452 kwargs: ScaleMeanVarianceKwargs 1453 1454 1455PreprocessingDescr = Annotated[ 1456 Union[ 1457 BinarizeDescr, 1458 ClipDescr, 1459 EnsureDtypeDescr, 1460 ScaleLinearDescr, 1461 SigmoidDescr, 1462 FixedZeroMeanUnitVarianceDescr, 1463 ZeroMeanUnitVarianceDescr, 1464 ScaleRangeDescr, 1465 ], 1466 Discriminator("id"), 1467] 1468PostprocessingDescr = Annotated[ 1469 Union[ 1470 BinarizeDescr, 1471 ClipDescr, 1472 EnsureDtypeDescr, 1473 ScaleLinearDescr, 1474 SigmoidDescr, 1475 FixedZeroMeanUnitVarianceDescr, 1476 ZeroMeanUnitVarianceDescr, 1477 ScaleRangeDescr, 1478 ScaleMeanVarianceDescr, 1479 ], 1480 Discriminator("id"), 1481] 1482 1483IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis) 1484 1485 1486class TensorDescrBase(Node, Generic[IO_AxisT]): 1487 id: TensorId 1488 """Tensor id. No duplicates are allowed.""" 1489 1490 description: Annotated[str, MaxLen(128)] = "" 1491 """free text description""" 1492 1493 axes: NotEmpty[Sequence[IO_AxisT]] 1494 """tensor axes""" 1495 1496 @property 1497 def shape(self): 1498 return tuple(a.size for a in self.axes) 1499 1500 @field_validator("axes", mode="after", check_fields=False) 1501 @classmethod 1502 def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]: 1503 batch_axes = [a for a in axes if a.type == "batch"] 1504 if len(batch_axes) > 1: 1505 raise ValueError( 1506 f"Only one batch axis (per tensor) allowed, but got {batch_axes}" 1507 ) 1508 1509 seen_ids: Set[AxisId] = set() 1510 duplicate_axes_ids: Set[AxisId] = set() 1511 for a in axes: 1512 (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id) 1513 1514 if duplicate_axes_ids: 1515 raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}") 1516 1517 return axes 1518 1519 test_tensor: FileDescr_ 1520 """An example tensor to use for testing. 1521 Using the model with the test input tensors is expected to yield the test output tensors. 1522 Each test tensor has be a an ndarray in the 1523 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format). 1524 The file extension must be '.npy'.""" 1525 1526 sample_tensor: Optional[FileDescr_] = None 1527 """A sample tensor to illustrate a possible input/output for the model, 1528 The sample image primarily serves to inform a human user about an example use case 1529 and is typically stored as .hdf5, .png or .tiff. 1530 It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats) 1531 (numpy's `.npy` format is not supported). 1532 The image dimensionality has to match the number of axes specified in this tensor description. 
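As a rough sketch of how the `shape` property above is derived from `axes` (the axes listed here are made up; a complete tensor description additionally requires a `test_tensor` file):

>>> axes = [
...     BatchAxis(),
...     ChannelAxis(channel_names=[Identifier("intensity")]),
...     SpaceInputAxis(id=AxisId("y"), size=128),
...     SpaceInputAxis(id=AxisId("x"), size=128),
... ]
>>> tuple(a.size for a in axes)  # equals `shape` of a tensor description with these axes
(None, 1, 128, 128)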
1533 """ 1534 1535 @model_validator(mode="after") 1536 def _validate_sample_tensor(self) -> Self: 1537 if self.sample_tensor is None or not get_validation_context().perform_io_checks: 1538 return self 1539 1540 reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256) 1541 tensor: NDArray[Any] = imread( 1542 reader.read(), 1543 extension=PurePosixPath(reader.original_file_name).suffix, 1544 ) 1545 n_dims = len(tensor.squeeze().shape) 1546 n_dims_min = n_dims_max = len(self.axes) 1547 1548 for a in self.axes: 1549 if isinstance(a, BatchAxis): 1550 n_dims_min -= 1 1551 elif isinstance(a.size, int): 1552 if a.size == 1: 1553 n_dims_min -= 1 1554 elif isinstance(a.size, (ParameterizedSize, DataDependentSize)): 1555 if a.size.min == 1: 1556 n_dims_min -= 1 1557 elif isinstance(a.size, SizeReference): 1558 if a.size.offset < 2: 1559 # size reference may result in singleton axis 1560 n_dims_min -= 1 1561 else: 1562 assert_never(a.size) 1563 1564 n_dims_min = max(0, n_dims_min) 1565 if n_dims < n_dims_min or n_dims > n_dims_max: 1566 raise ValueError( 1567 f"Expected sample tensor to have {n_dims_min} to" 1568 + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})." 1569 ) 1570 1571 return self 1572 1573 data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = ( 1574 IntervalOrRatioDataDescr() 1575 ) 1576 """Description of the tensor's data values, optionally per channel. 1577 If specified per channel, the data `type` needs to match across channels.""" 1578 1579 @property 1580 def dtype( 1581 self, 1582 ) -> Literal[ 1583 "float32", 1584 "float64", 1585 "uint8", 1586 "int8", 1587 "uint16", 1588 "int16", 1589 "uint32", 1590 "int32", 1591 "uint64", 1592 "int64", 1593 "bool", 1594 ]: 1595 """dtype as specified under `data.type` or `data[i].type`""" 1596 if isinstance(self.data, collections.abc.Sequence): 1597 return self.data[0].type 1598 else: 1599 return self.data.type 1600 1601 @field_validator("data", mode="after") 1602 @classmethod 1603 def _check_data_type_across_channels( 1604 cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] 1605 ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]: 1606 if not isinstance(value, list): 1607 return value 1608 1609 dtypes = {t.type for t in value} 1610 if len(dtypes) > 1: 1611 raise ValueError( 1612 "Tensor data descriptions per channel need to agree in their data" 1613 + f" `type`, but found {dtypes}." 1614 ) 1615 1616 return value 1617 1618 @model_validator(mode="after") 1619 def _check_data_matches_channelaxis(self) -> Self: 1620 if not isinstance(self.data, (list, tuple)): 1621 return self 1622 1623 for a in self.axes: 1624 if isinstance(a, ChannelAxis): 1625 size = a.size 1626 assert isinstance(size, int) 1627 break 1628 else: 1629 return self 1630 1631 if len(self.data) != size: 1632 raise ValueError( 1633 f"Got tensor data descriptions for {len(self.data)} channels, but" 1634 + f" '{a.id}' axis has size {size}." 1635 ) 1636 1637 return self 1638 1639 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]: 1640 if len(array.shape) != len(self.axes): 1641 raise ValueError( 1642 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})" 1643 + f" incompatible with {len(self.axes)} axes." 1644 ) 1645 return {a.id: array.shape[i] for i, a in enumerate(self.axes)} 1646 1647 1648class InputTensorDescr(TensorDescrBase[InputAxis]): 1649 id: TensorId = TensorId("input") 1650 """Input tensor id. 
1651 No duplicates are allowed across all inputs and outputs.""" 1652 1653 optional: bool = False 1654 """indicates that this tensor may be `None`""" 1655 1656 preprocessing: List[PreprocessingDescr] = Field( 1657 default_factory=cast(Callable[[], List[PreprocessingDescr]], list) 1658 ) 1659 1660 """Description of how this input should be preprocessed. 1661 1662 notes: 1663 - If preprocessing does not start with an 'ensure_dtype' entry, it is added 1664 to ensure an input tensor's data type matches the input tensor's data description. 1665 - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 1666 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally 1667 changing the data type. 1668 """ 1669 1670 @model_validator(mode="after") 1671 def _validate_preprocessing_kwargs(self) -> Self: 1672 axes_ids = [a.id for a in self.axes] 1673 for p in self.preprocessing: 1674 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 1675 if kwargs_axes is None: 1676 continue 1677 1678 if not isinstance(kwargs_axes, collections.abc.Sequence): 1679 raise ValueError( 1680 f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}" 1681 ) 1682 1683 if any(a not in axes_ids for a in kwargs_axes): 1684 raise ValueError( 1685 "`preprocessing.i.kwargs.axes` needs to be subset of axes ids" 1686 ) 1687 1688 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 1689 dtype = self.data.type 1690 else: 1691 dtype = self.data[0].type 1692 1693 # ensure `preprocessing` begins with `EnsureDtypeDescr` 1694 if not self.preprocessing or not isinstance( 1695 self.preprocessing[0], EnsureDtypeDescr 1696 ): 1697 self.preprocessing.insert( 1698 0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1699 ) 1700 1701 # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 1702 if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)): 1703 self.preprocessing.append( 1704 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1705 ) 1706 1707 return self 1708 1709 1710def convert_axes( 1711 axes: str, 1712 *, 1713 shape: Union[ 1714 Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4 1715 ], 1716 tensor_type: Literal["input", "output"], 1717 halo: Optional[Sequence[int]], 1718 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 1719): 1720 ret: List[AnyAxis] = [] 1721 for i, a in enumerate(axes): 1722 axis_type = _AXIS_TYPE_MAP.get(a, a) 1723 if axis_type == "batch": 1724 ret.append(BatchAxis()) 1725 continue 1726 1727 scale = 1.0 1728 if isinstance(shape, _ParameterizedInputShape_v0_4): 1729 if shape.step[i] == 0: 1730 size = shape.min[i] 1731 else: 1732 size = ParameterizedSize(min=shape.min[i], step=shape.step[i]) 1733 elif isinstance(shape, _ImplicitOutputShape_v0_4): 1734 ref_t = str(shape.reference_tensor) 1735 if ref_t.count(".") == 1: 1736 t_id, orig_a_id = ref_t.split(".") 1737 else: 1738 t_id = ref_t 1739 orig_a_id = a 1740 1741 a_id = _AXIS_ID_MAP.get(orig_a_id, a) 1742 if not (orig_scale := shape.scale[i]): 1743 # old way to insert a new axis dimension 1744 size = int(2 * shape.offset[i]) 1745 else: 1746 scale = 1 / orig_scale 1747 if axis_type in ("channel", "index"): 1748 # these axes no longer have a scale 1749 offset_from_scale = orig_scale * size_refs.get( 1750 _TensorName_v0_4(t_id), {} 1751 ).get(orig_a_id, 0) 1752 else: 1753 offset_from_scale = 0 1754 size = SizeReference( 1755 tensor_id=TensorId(t_id), 1756 axis_id=AxisId(a_id), 1757 
offset=int(offset_from_scale + 2 * shape.offset[i]), 1758 ) 1759 else: 1760 size = shape[i] 1761 1762 if axis_type == "time": 1763 if tensor_type == "input": 1764 ret.append(TimeInputAxis(size=size, scale=scale)) 1765 else: 1766 assert not isinstance(size, ParameterizedSize) 1767 if halo is None: 1768 ret.append(TimeOutputAxis(size=size, scale=scale)) 1769 else: 1770 assert not isinstance(size, int) 1771 ret.append( 1772 TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i]) 1773 ) 1774 1775 elif axis_type == "index": 1776 if tensor_type == "input": 1777 ret.append(IndexInputAxis(size=size)) 1778 else: 1779 if isinstance(size, ParameterizedSize): 1780 size = DataDependentSize(min=size.min) 1781 1782 ret.append(IndexOutputAxis(size=size)) 1783 elif axis_type == "channel": 1784 assert not isinstance(size, ParameterizedSize) 1785 if isinstance(size, SizeReference): 1786 warnings.warn( 1787 "Conversion of channel size from an implicit output shape may be" 1788 + " wrong" 1789 ) 1790 ret.append( 1791 ChannelAxis( 1792 channel_names=[ 1793 Identifier(f"channel{i}") for i in range(size.offset) 1794 ] 1795 ) 1796 ) 1797 else: 1798 ret.append( 1799 ChannelAxis( 1800 channel_names=[Identifier(f"channel{i}") for i in range(size)] 1801 ) 1802 ) 1803 elif axis_type == "space": 1804 if tensor_type == "input": 1805 ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale)) 1806 else: 1807 assert not isinstance(size, ParameterizedSize) 1808 if halo is None or halo[i] == 0: 1809 ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale)) 1810 elif isinstance(size, int): 1811 raise NotImplementedError( 1812 f"output axis with halo and fixed size (here {size}) not allowed" 1813 ) 1814 else: 1815 ret.append( 1816 SpaceOutputAxisWithHalo( 1817 id=AxisId(a), size=size, scale=scale, halo=halo[i] 1818 ) 1819 ) 1820 1821 return ret 1822 1823 1824def _axes_letters_to_ids( 1825 axes: Optional[str], 1826) -> Optional[List[AxisId]]: 1827 if axes is None: 1828 return None 1829 1830 return [AxisId(a) for a in axes] 1831 1832 1833def _get_complement_v04_axis( 1834 tensor_axes: Sequence[str], axes: Optional[Sequence[str]] 1835) -> Optional[AxisId]: 1836 if axes is None: 1837 return None 1838 1839 non_complement_axes = set(axes) | {"b"} 1840 complement_axes = [a for a in tensor_axes if a not in non_complement_axes] 1841 if len(complement_axes) > 1: 1842 raise ValueError( 1843 f"Expected none or a single complement axis, but axes '{axes}' " 1844 + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'." 
1845 ) 1846 1847 return None if not complement_axes else AxisId(complement_axes[0]) 1848 1849 1850def _convert_proc( 1851 p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4], 1852 tensor_axes: Sequence[str], 1853) -> Union[PreprocessingDescr, PostprocessingDescr]: 1854 if isinstance(p, _BinarizeDescr_v0_4): 1855 return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold)) 1856 elif isinstance(p, _ClipDescr_v0_4): 1857 return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max)) 1858 elif isinstance(p, _SigmoidDescr_v0_4): 1859 return SigmoidDescr() 1860 elif isinstance(p, _ScaleLinearDescr_v0_4): 1861 axes = _axes_letters_to_ids(p.kwargs.axes) 1862 if p.kwargs.axes is None: 1863 axis = None 1864 else: 1865 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes) 1866 1867 if axis is None: 1868 assert not isinstance(p.kwargs.gain, list) 1869 assert not isinstance(p.kwargs.offset, list) 1870 kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset) 1871 else: 1872 kwargs = ScaleLinearAlongAxisKwargs( 1873 axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset 1874 ) 1875 return ScaleLinearDescr(kwargs=kwargs) 1876 elif isinstance(p, _ScaleMeanVarianceDescr_v0_4): 1877 return ScaleMeanVarianceDescr( 1878 kwargs=ScaleMeanVarianceKwargs( 1879 axes=_axes_letters_to_ids(p.kwargs.axes), 1880 reference_tensor=TensorId(str(p.kwargs.reference_tensor)), 1881 eps=p.kwargs.eps, 1882 ) 1883 ) 1884 elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4): 1885 if p.kwargs.mode == "fixed": 1886 mean = p.kwargs.mean 1887 std = p.kwargs.std 1888 assert mean is not None 1889 assert std is not None 1890 1891 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes) 1892 1893 if axis is None: 1894 return FixedZeroMeanUnitVarianceDescr( 1895 kwargs=FixedZeroMeanUnitVarianceKwargs( 1896 mean=mean, std=std # pyright: ignore[reportArgumentType] 1897 ) 1898 ) 1899 else: 1900 if not isinstance(mean, list): 1901 mean = [float(mean)] 1902 if not isinstance(std, list): 1903 std = [float(std)] 1904 1905 return FixedZeroMeanUnitVarianceDescr( 1906 kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 1907 axis=axis, mean=mean, std=std 1908 ) 1909 ) 1910 1911 else: 1912 axes = _axes_letters_to_ids(p.kwargs.axes) or [] 1913 if p.kwargs.mode == "per_dataset": 1914 axes = [AxisId("batch")] + axes 1915 if not axes: 1916 axes = None 1917 return ZeroMeanUnitVarianceDescr( 1918 kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps) 1919 ) 1920 1921 elif isinstance(p, _ScaleRangeDescr_v0_4): 1922 return ScaleRangeDescr( 1923 kwargs=ScaleRangeKwargs( 1924 axes=_axes_letters_to_ids(p.kwargs.axes), 1925 min_percentile=p.kwargs.min_percentile, 1926 max_percentile=p.kwargs.max_percentile, 1927 eps=p.kwargs.eps, 1928 ) 1929 ) 1930 else: 1931 assert_never(p) 1932 1933 1934class _InputTensorConv( 1935 Converter[ 1936 _InputTensorDescr_v0_4, 1937 InputTensorDescr, 1938 FileSource_, 1939 Optional[FileSource_], 1940 Mapping[_TensorName_v0_4, Mapping[str, int]], 1941 ] 1942): 1943 def _convert( 1944 self, 1945 src: _InputTensorDescr_v0_4, 1946 tgt: "type[InputTensorDescr] | type[dict[str, Any]]", 1947 test_tensor: FileSource_, 1948 sample_tensor: Optional[FileSource_], 1949 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 1950 ) -> "InputTensorDescr | dict[str, Any]": 1951 axes: List[InputAxis] = convert_axes( # pyright: ignore[reportAssignmentType] 1952 src.axes, 1953 shape=src.shape, 1954 tensor_type="input", 1955 halo=None, 1956 size_refs=size_refs, 1957 ) 1958 prep: 
List[PreprocessingDescr] = [] 1959 for p in src.preprocessing: 1960 cp = _convert_proc(p, src.axes) 1961 assert not isinstance(cp, ScaleMeanVarianceDescr) 1962 prep.append(cp) 1963 1964 prep.append(EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32"))) 1965 1966 return tgt( 1967 axes=axes, 1968 id=TensorId(str(src.name)), 1969 test_tensor=FileDescr(source=test_tensor), 1970 sample_tensor=( 1971 None if sample_tensor is None else FileDescr(source=sample_tensor) 1972 ), 1973 data=dict(type=src.data_type), # pyright: ignore[reportArgumentType] 1974 preprocessing=prep, 1975 ) 1976 1977 1978_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr) 1979 1980 1981class OutputTensorDescr(TensorDescrBase[OutputAxis]): 1982 id: TensorId = TensorId("output") 1983 """Output tensor id. 1984 No duplicates are allowed across all inputs and outputs.""" 1985 1986 postprocessing: List[PostprocessingDescr] = Field( 1987 default_factory=cast(Callable[[], List[PostprocessingDescr]], list) 1988 ) 1989 """Description of how this output should be postprocessed. 1990 1991 note: `postprocessing` always ends with an 'ensure_dtype' operation. 1992 If not given this is added to cast to this tensor's `data.type`. 1993 """ 1994 1995 @model_validator(mode="after") 1996 def _validate_postprocessing_kwargs(self) -> Self: 1997 axes_ids = [a.id for a in self.axes] 1998 for p in self.postprocessing: 1999 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 2000 if kwargs_axes is None: 2001 continue 2002 2003 if not isinstance(kwargs_axes, collections.abc.Sequence): 2004 raise ValueError( 2005 f"expected `axes` sequence, but got {type(kwargs_axes)}" 2006 ) 2007 2008 if any(a not in axes_ids for a in kwargs_axes): 2009 raise ValueError("`kwargs.axes` needs to be subset of axes ids") 2010 2011 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 2012 dtype = self.data.type 2013 else: 2014 dtype = self.data[0].type 2015 2016 # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 2017 if not self.postprocessing or not isinstance( 2018 self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr) 2019 ): 2020 self.postprocessing.append( 2021 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 2022 ) 2023 return self 2024 2025 2026class _OutputTensorConv( 2027 Converter[ 2028 _OutputTensorDescr_v0_4, 2029 OutputTensorDescr, 2030 FileSource_, 2031 Optional[FileSource_], 2032 Mapping[_TensorName_v0_4, Mapping[str, int]], 2033 ] 2034): 2035 def _convert( 2036 self, 2037 src: _OutputTensorDescr_v0_4, 2038 tgt: "type[OutputTensorDescr] | type[dict[str, Any]]", 2039 test_tensor: FileSource_, 2040 sample_tensor: Optional[FileSource_], 2041 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 2042 ) -> "OutputTensorDescr | dict[str, Any]": 2043 # TODO: split convert_axes into convert_output_axes and convert_input_axes 2044 axes: List[OutputAxis] = convert_axes( # pyright: ignore[reportAssignmentType] 2045 src.axes, 2046 shape=src.shape, 2047 tensor_type="output", 2048 halo=src.halo, 2049 size_refs=size_refs, 2050 ) 2051 data_descr: Dict[str, Any] = dict(type=src.data_type) 2052 if data_descr["type"] == "bool": 2053 data_descr["values"] = [False, True] 2054 2055 return tgt( 2056 axes=axes, 2057 id=TensorId(str(src.name)), 2058 test_tensor=FileDescr(source=test_tensor), 2059 sample_tensor=( 2060 None if sample_tensor is None else FileDescr(source=sample_tensor) 2061 ), 2062 data=data_descr, # pyright: ignore[reportArgumentType] 2063 
postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing], 2064 ) 2065 2066 2067_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr) 2068 2069 2070TensorDescr = Union[InputTensorDescr, OutputTensorDescr] 2071 2072 2073def validate_tensors( 2074 tensors: Mapping[TensorId, Tuple[TensorDescr, NDArray[Any]]], 2075 tensor_origin: Literal[ 2076 "test_tensor" 2077 ], # for more precise error messages, e.g. 'test_tensor' 2078): 2079 all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, int]]] = {} 2080 2081 def e_msg(d: TensorDescr): 2082 return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]" 2083 2084 for descr, array in tensors.values(): 2085 try: 2086 axis_sizes = descr.get_axis_sizes_for_array(array) 2087 except ValueError as e: 2088 raise ValueError(f"{e_msg(descr)} {e}") 2089 else: 2090 all_tensor_axes[descr.id] = { 2091 a.id: (a, axis_sizes[a.id]) for a in descr.axes 2092 } 2093 2094 for descr, array in tensors.values(): 2095 if descr.dtype in ("float32", "float64"): 2096 invalid_test_tensor_dtype = array.dtype.name not in ( 2097 "float32", 2098 "float64", 2099 "uint8", 2100 "int8", 2101 "uint16", 2102 "int16", 2103 "uint32", 2104 "int32", 2105 "uint64", 2106 "int64", 2107 ) 2108 else: 2109 invalid_test_tensor_dtype = array.dtype.name != descr.dtype 2110 2111 if invalid_test_tensor_dtype: 2112 raise ValueError( 2113 f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not" 2114 + f" match described dtype '{descr.dtype}'" 2115 ) 2116 2117 if array.min() > -1e-4 and array.max() < 1e-4: 2118 raise ValueError( 2119 "Output values are too small for reliable testing." 2120 + f" Values <-1e5 or >=1e5 must be present in {tensor_origin}" 2121 ) 2122 2123 for a in descr.axes: 2124 actual_size = all_tensor_axes[descr.id][a.id][1] 2125 if a.size is None: 2126 continue 2127 2128 if isinstance(a.size, int): 2129 if actual_size != a.size: 2130 raise ValueError( 2131 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' " 2132 + f"has incompatible size {actual_size}, expected {a.size}" 2133 ) 2134 elif isinstance(a.size, ParameterizedSize): 2135 _ = a.size.validate_size(actual_size) 2136 elif isinstance(a.size, DataDependentSize): 2137 _ = a.size.validate_size(actual_size) 2138 elif isinstance(a.size, SizeReference): 2139 ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id) 2140 if ref_tensor_axes is None: 2141 raise ValueError( 2142 f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor" 2143 + f" reference '{a.size.tensor_id}'" 2144 ) 2145 2146 ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None)) 2147 if ref_axis is None or ref_size is None: 2148 raise ValueError( 2149 f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis" 2150 + f" reference '{a.size.tensor_id}.{a.size.axis_id}" 2151 ) 2152 2153 if a.unit != ref_axis.unit: 2154 raise ValueError( 2155 f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires" 2156 + " axis and reference axis to have the same `unit`, but" 2157 + f" {a.unit}!={ref_axis.unit}" 2158 ) 2159 2160 if actual_size != ( 2161 expected_size := ( 2162 ref_size * ref_axis.scale / a.scale + a.size.offset 2163 ) 2164 ): 2165 raise ValueError( 2166 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size" 2167 + f" {actual_size} invalid for referenced size {ref_size};" 2168 + f" expected {expected_size}" 2169 ) 2170 else: 2171 assert_never(a.size) 2172 2173 2174FileDescr_dependencies = Annotated[ 2175 FileDescr_, 2176 WithSuffix((".yaml", ".yml"), case_sensitive=True), 
2177 Field(examples=[dict(source="environment.yaml")]), 2178] 2179 2180 2181class _ArchitectureCallableDescr(Node): 2182 callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])] 2183 """Identifier of the callable that returns a torch.nn.Module instance.""" 2184 2185 kwargs: Dict[str, YamlValue] = Field( 2186 default_factory=cast(Callable[[], Dict[str, YamlValue]], dict) 2187 ) 2188 """key word arguments for the `callable`""" 2189 2190 2191class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr): 2192 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2193 """Architecture source file""" 2194 2195 @model_serializer(mode="wrap", when_used="unless-none") 2196 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 2197 return package_file_descr_serializer(self, nxt, info) 2198 2199 2200class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr): 2201 import_from: str 2202 """Where to import the callable from, i.e. `from <import_from> import <callable>`""" 2203 2204 2205class _ArchFileConv( 2206 Converter[ 2207 _CallableFromFile_v0_4, 2208 ArchitectureFromFileDescr, 2209 Optional[Sha256], 2210 Dict[str, Any], 2211 ] 2212): 2213 def _convert( 2214 self, 2215 src: _CallableFromFile_v0_4, 2216 tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]", 2217 sha256: Optional[Sha256], 2218 kwargs: Dict[str, Any], 2219 ) -> "ArchitectureFromFileDescr | dict[str, Any]": 2220 if src.startswith("http") and src.count(":") == 2: 2221 http, source, callable_ = src.split(":") 2222 source = ":".join((http, source)) 2223 elif not src.startswith("http") and src.count(":") == 1: 2224 source, callable_ = src.split(":") 2225 else: 2226 source = str(src) 2227 callable_ = str(src) 2228 return tgt( 2229 callable=Identifier(callable_), 2230 source=cast(FileSource_, source), 2231 sha256=sha256, 2232 kwargs=kwargs, 2233 ) 2234 2235 2236_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr) 2237 2238 2239class _ArchLibConv( 2240 Converter[ 2241 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any] 2242 ] 2243): 2244 def _convert( 2245 self, 2246 src: _CallableFromDepencency_v0_4, 2247 tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]", 2248 kwargs: Dict[str, Any], 2249 ) -> "ArchitectureFromLibraryDescr | dict[str, Any]": 2250 *mods, callable_ = src.split(".") 2251 import_from = ".".join(mods) 2252 return tgt( 2253 import_from=import_from, callable=Identifier(callable_), kwargs=kwargs 2254 ) 2255 2256 2257_arch_lib_conv = _ArchLibConv( 2258 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr 2259) 2260 2261 2262class WeightsEntryDescrBase(FileDescr): 2263 type: ClassVar[WeightsFormat] 2264 weights_format_name: ClassVar[str] # human readable 2265 2266 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2267 """Source of the weights file.""" 2268 2269 authors: Optional[List[Author]] = None 2270 """Authors 2271 Either the person(s) that have trained this model resulting in the original weights file. 2272 (If this is the initial weights entry, i.e. it does not have a `parent`) 2273 Or the person(s) who have converted the weights to this weights format. 2274 (If this is a child weight, i.e. it has a `parent` field) 2275 """ 2276 2277 parent: Annotated[ 2278 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"]) 2279 ] = None 2280 """The source weights these weights were converted from. 
2281 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`, 2282 The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights. 2283 All weight entries except one (the initial set of weights resulting from training the model), 2284 need to have this field.""" 2285 2286 comment: str = "" 2287 """A comment about this weights entry, for example how these weights were created.""" 2288 2289 @model_validator(mode="after") 2290 def _validate(self) -> Self: 2291 if self.type == self.parent: 2292 raise ValueError("Weights entry can't be it's own parent.") 2293 2294 return self 2295 2296 @model_serializer(mode="wrap", when_used="unless-none") 2297 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 2298 return package_file_descr_serializer(self, nxt, info) 2299 2300 2301class KerasHdf5WeightsDescr(WeightsEntryDescrBase): 2302 type = "keras_hdf5" 2303 weights_format_name: ClassVar[str] = "Keras HDF5" 2304 tensorflow_version: Version 2305 """TensorFlow version used to create these weights.""" 2306 2307 2308class OnnxWeightsDescr(WeightsEntryDescrBase): 2309 type = "onnx" 2310 weights_format_name: ClassVar[str] = "ONNX" 2311 opset_version: Annotated[int, Ge(7)] 2312 """ONNX opset version""" 2313 2314 2315class PytorchStateDictWeightsDescr(WeightsEntryDescrBase): 2316 type = "pytorch_state_dict" 2317 weights_format_name: ClassVar[str] = "Pytorch State Dict" 2318 architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr] 2319 pytorch_version: Version 2320 """Version of the PyTorch library used. 2321 If `architecture.depencencies` is specified it has to include pytorch and any version pinning has to be compatible. 2322 """ 2323 dependencies: Optional[FileDescr_dependencies] = None 2324 """Custom depencies beyond pytorch described in a Conda environment file. 2325 Allows to specify custom dependencies, see conda docs: 2326 - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms) 2327 - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually) 2328 2329 The conda environment file should include pytorch and any version pinning has to be compatible with 2330 **pytorch_version**. 2331 """ 2332 2333 2334class TensorflowJsWeightsDescr(WeightsEntryDescrBase): 2335 type = "tensorflow_js" 2336 weights_format_name: ClassVar[str] = "Tensorflow.js" 2337 tensorflow_version: Version 2338 """Version of the TensorFlow library used.""" 2339 2340 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2341 """The multi-file weights. 2342 All required files/folders should be a zip archive.""" 2343 2344 2345class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase): 2346 type = "tensorflow_saved_model_bundle" 2347 weights_format_name: ClassVar[str] = "Tensorflow Saved Model" 2348 tensorflow_version: Version 2349 """Version of the TensorFlow library used.""" 2350 2351 dependencies: Optional[FileDescr_dependencies] = None 2352 """Custom dependencies beyond tensorflow. 2353 Should include tensorflow and any version pinning has to be compatible with **tensorflow_version**.""" 2354 2355 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2356 """The multi-file weights. 
2357 All required files/folders should be a zip archive.""" 2358 2359 2360class TorchscriptWeightsDescr(WeightsEntryDescrBase): 2361 type = "torchscript" 2362 weights_format_name: ClassVar[str] = "TorchScript" 2363 pytorch_version: Version 2364 """Version of the PyTorch library used.""" 2365 2366 2367class WeightsDescr(Node): 2368 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None 2369 onnx: Optional[OnnxWeightsDescr] = None 2370 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None 2371 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None 2372 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = ( 2373 None 2374 ) 2375 torchscript: Optional[TorchscriptWeightsDescr] = None 2376 2377 @model_validator(mode="after") 2378 def check_entries(self) -> Self: 2379 entries = {wtype for wtype, entry in self if entry is not None} 2380 2381 if not entries: 2382 raise ValueError("Missing weights entry") 2383 2384 entries_wo_parent = { 2385 wtype 2386 for wtype, entry in self 2387 if entry is not None and hasattr(entry, "parent") and entry.parent is None 2388 } 2389 if len(entries_wo_parent) != 1: 2390 issue_warning( 2391 "Exactly one weights entry may not specify the `parent` field (got" 2392 + " {value}). That entry is considered the original set of model weights." 2393 + " Other weight formats are created through conversion of the orignal or" 2394 + " already converted weights. They have to reference the weights format" 2395 + " they were converted from as their `parent`.", 2396 value=len(entries_wo_parent), 2397 field="weights", 2398 ) 2399 2400 for wtype, entry in self: 2401 if entry is None: 2402 continue 2403 2404 assert hasattr(entry, "type") 2405 assert hasattr(entry, "parent") 2406 assert wtype == entry.type 2407 if ( 2408 entry.parent is not None and entry.parent not in entries 2409 ): # self reference checked for `parent` field 2410 raise ValueError( 2411 f"`weights.{wtype}.parent={entry.parent} not in specified weight" 2412 + f" formats: {entries}" 2413 ) 2414 2415 return self 2416 2417 def __getitem__( 2418 self, 2419 key: Literal[ 2420 "keras_hdf5", 2421 "onnx", 2422 "pytorch_state_dict", 2423 "tensorflow_js", 2424 "tensorflow_saved_model_bundle", 2425 "torchscript", 2426 ], 2427 ): 2428 if key == "keras_hdf5": 2429 ret = self.keras_hdf5 2430 elif key == "onnx": 2431 ret = self.onnx 2432 elif key == "pytorch_state_dict": 2433 ret = self.pytorch_state_dict 2434 elif key == "tensorflow_js": 2435 ret = self.tensorflow_js 2436 elif key == "tensorflow_saved_model_bundle": 2437 ret = self.tensorflow_saved_model_bundle 2438 elif key == "torchscript": 2439 ret = self.torchscript 2440 else: 2441 raise KeyError(key) 2442 2443 if ret is None: 2444 raise KeyError(key) 2445 2446 return ret 2447 2448 @property 2449 def available_formats(self): 2450 return { 2451 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 2452 **({} if self.onnx is None else {"onnx": self.onnx}), 2453 **( 2454 {} 2455 if self.pytorch_state_dict is None 2456 else {"pytorch_state_dict": self.pytorch_state_dict} 2457 ), 2458 **( 2459 {} 2460 if self.tensorflow_js is None 2461 else {"tensorflow_js": self.tensorflow_js} 2462 ), 2463 **( 2464 {} 2465 if self.tensorflow_saved_model_bundle is None 2466 else { 2467 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 2468 } 2469 ), 2470 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 2471 } 2472 2473 @property 2474 def missing_formats(self): 2475 return { 2476 wf for wf in 
get_args(WeightsFormat) if wf not in self.available_formats 2477 } 2478 2479 2480class ModelId(ResourceId): 2481 pass 2482 2483 2484class LinkedModel(LinkedResourceBase): 2485 """Reference to a bioimage.io model.""" 2486 2487 id: ModelId 2488 """A valid model `id` from the bioimage.io collection.""" 2489 2490 2491class _DataDepSize(NamedTuple): 2492 min: StrictInt 2493 max: Optional[StrictInt] 2494 2495 2496class _AxisSizes(NamedTuple): 2497 """the lenghts of all axes of model inputs and outputs""" 2498 2499 inputs: Dict[Tuple[TensorId, AxisId], int] 2500 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] 2501 2502 2503class _TensorSizes(NamedTuple): 2504 """_AxisSizes as nested dicts""" 2505 2506 inputs: Dict[TensorId, Dict[AxisId, int]] 2507 outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]] 2508 2509 2510class ReproducibilityTolerance(Node, extra="allow"): 2511 """Describes what small numerical differences -- if any -- may be tolerated 2512 in the generated output when executing in different environments. 2513 2514 A tensor element *output* is considered mismatched to the **test_tensor** if 2515 abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**). 2516 (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).) 2517 2518 Motivation: 2519 For testing we can request the respective deep learning frameworks to be as 2520 reproducible as possible by setting seeds and chosing deterministic algorithms, 2521 but differences in operating systems, available hardware and installed drivers 2522 may still lead to numerical differences. 2523 """ 2524 2525 relative_tolerance: RelativeTolerance = 1e-3 2526 """Maximum relative tolerance of reproduced test tensor.""" 2527 2528 absolute_tolerance: AbsoluteTolerance = 1e-4 2529 """Maximum absolute tolerance of reproduced test tensor.""" 2530 2531 mismatched_elements_per_million: MismatchedElementsPerMillion = 100 2532 """Maximum number of mismatched elements/pixels per million to tolerate.""" 2533 2534 output_ids: Sequence[TensorId] = () 2535 """Limits the output tensor IDs these reproducibility details apply to.""" 2536 2537 weights_formats: Sequence[WeightsFormat] = () 2538 """Limits the weights formats these details apply to.""" 2539 2540 2541class BioimageioConfig(Node, extra="allow"): 2542 reproducibility_tolerance: Sequence[ReproducibilityTolerance] = () 2543 """Tolerances to allow when reproducing the model's test outputs 2544 from the model's test inputs. 2545 Only the first entry matching tensor id and weights format is considered. 2546 """ 2547 2548 2549class Config(Node, extra="allow"): 2550 bioimageio: BioimageioConfig = Field(default_factory=BioimageioConfig) 2551 2552 2553class ModelDescr(GenericModelDescrBase): 2554 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. 2555 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 2556 """ 2557 2558 implemented_format_version: ClassVar[Literal["0.5.4"]] = "0.5.4" 2559 if TYPE_CHECKING: 2560 format_version: Literal["0.5.4"] = "0.5.4" 2561 else: 2562 format_version: Literal["0.5.4"] 2563 """Version of the bioimage.io model description specification used. 2564 When creating a new model always use the latest micro/patch version described here. 
2565 The `format_version` is important for any consumer software to understand how to parse the fields. 2566 """ 2567 2568 implemented_type: ClassVar[Literal["model"]] = "model" 2569 if TYPE_CHECKING: 2570 type: Literal["model"] = "model" 2571 else: 2572 type: Literal["model"] 2573 """Specialized resource type 'model'""" 2574 2575 id: Optional[ModelId] = None 2576 """bioimage.io-wide unique resource identifier 2577 assigned by bioimage.io; version **un**specific.""" 2578 2579 authors: NotEmpty[List[Author]] 2580 """The authors are the creators of the model RDF and the primary points of contact.""" 2581 2582 documentation: FileSource_documentation 2583 """URL or relative path to a markdown file with additional documentation. 2584 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. 2585 The documentation should include a '#[#] Validation' (sub)section 2586 with details on how to quantitatively validate the model on unseen data.""" 2587 2588 @field_validator("documentation", mode="after") 2589 @classmethod 2590 def _validate_documentation( 2591 cls, value: FileSource_documentation 2592 ) -> FileSource_documentation: 2593 if not get_validation_context().perform_io_checks: 2594 return value 2595 2596 doc_reader = get_reader(value) 2597 doc_content = doc_reader.read().decode(encoding="utf-8") 2598 if not re.search("#.*[vV]alidation", doc_content): 2599 issue_warning( 2600 "No '# Validation' (sub)section found in {value}.", 2601 value=value, 2602 field="documentation", 2603 ) 2604 2605 return value 2606 2607 inputs: NotEmpty[Sequence[InputTensorDescr]] 2608 """Describes the input tensors expected by this model.""" 2609 2610 @field_validator("inputs", mode="after") 2611 @classmethod 2612 def _validate_input_axes( 2613 cls, inputs: Sequence[InputTensorDescr] 2614 ) -> Sequence[InputTensorDescr]: 2615 input_size_refs = cls._get_axes_with_independent_size(inputs) 2616 2617 for i, ipt in enumerate(inputs): 2618 valid_independent_refs: Dict[ 2619 Tuple[TensorId, AxisId], 2620 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2621 ] = { 2622 **{ 2623 (ipt.id, a.id): (ipt, a, a.size) 2624 for a in ipt.axes 2625 if not isinstance(a, BatchAxis) 2626 and isinstance(a.size, (int, ParameterizedSize)) 2627 }, 2628 **input_size_refs, 2629 } 2630 for a, ax in enumerate(ipt.axes): 2631 cls._validate_axis( 2632 "inputs", 2633 i=i, 2634 tensor_id=ipt.id, 2635 a=a, 2636 axis=ax, 2637 valid_independent_refs=valid_independent_refs, 2638 ) 2639 return inputs 2640 2641 @staticmethod 2642 def _validate_axis( 2643 field_name: str, 2644 i: int, 2645 tensor_id: TensorId, 2646 a: int, 2647 axis: AnyAxis, 2648 valid_independent_refs: Dict[ 2649 Tuple[TensorId, AxisId], 2650 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2651 ], 2652 ): 2653 if isinstance(axis, BatchAxis) or isinstance( 2654 axis.size, (int, ParameterizedSize, DataDependentSize) 2655 ): 2656 return 2657 elif not isinstance(axis.size, SizeReference): 2658 assert_never(axis.size) 2659 2660 # validate axis.size SizeReference 2661 ref = (axis.size.tensor_id, axis.size.axis_id) 2662 if ref not in valid_independent_refs: 2663 raise ValueError( 2664 "Invalid tensor axis reference at" 2665 + f" {field_name}[{i}].axes[{a}].size: {axis.size}." 
2666 ) 2667 if ref == (tensor_id, axis.id): 2668 raise ValueError( 2669 "Self-referencing not allowed for" 2670 + f" {field_name}[{i}].axes[{a}].size: {axis.size}" 2671 ) 2672 if axis.type == "channel": 2673 if valid_independent_refs[ref][1].type != "channel": 2674 raise ValueError( 2675 "A channel axis' size may only reference another fixed size" 2676 + " channel axis." 2677 ) 2678 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names: 2679 ref_size = valid_independent_refs[ref][2] 2680 assert isinstance(ref_size, int), ( 2681 "channel axis ref (another channel axis) has to specify fixed" 2682 + " size" 2683 ) 2684 generated_channel_names = [ 2685 Identifier(axis.channel_names.format(i=i)) 2686 for i in range(1, ref_size + 1) 2687 ] 2688 axis.channel_names = generated_channel_names 2689 2690 if (ax_unit := getattr(axis, "unit", None)) != ( 2691 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None) 2692 ): 2693 raise ValueError( 2694 "The units of an axis and its reference axis need to match, but" 2695 + f" '{ax_unit}' != '{ref_unit}'." 2696 ) 2697 ref_axis = valid_independent_refs[ref][1] 2698 if isinstance(ref_axis, BatchAxis): 2699 raise ValueError( 2700 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}" 2701 + " (a batch axis is not allowed as reference)." 2702 ) 2703 2704 if isinstance(axis, WithHalo): 2705 min_size = axis.size.get_size(axis, ref_axis, n=0) 2706 if (min_size - 2 * axis.halo) < 1: 2707 raise ValueError( 2708 f"axis {axis.id} with minimum size {min_size} is too small for halo" 2709 + f" {axis.halo}." 2710 ) 2711 2712 input_halo = axis.halo * axis.scale / ref_axis.scale 2713 if input_halo != int(input_halo) or input_halo % 2 == 1: 2714 raise ValueError( 2715 f"input_halo {input_halo} (output_halo {axis.halo} *" 2716 + f" output_scale {axis.scale} / input_scale {ref_axis.scale})" 2717 + f" {tensor_id}.{axis.id}." 
2718 ) 2719 2720 @model_validator(mode="after") 2721 def _validate_test_tensors(self) -> Self: 2722 if not get_validation_context().perform_io_checks: 2723 return self 2724 2725 test_output_arrays = [load_array(descr.test_tensor) for descr in self.outputs] 2726 test_input_arrays = [load_array(descr.test_tensor) for descr in self.inputs] 2727 2728 tensors = { 2729 descr.id: (descr, array) 2730 for descr, array in zip( 2731 chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays 2732 ) 2733 } 2734 validate_tensors(tensors, tensor_origin="test_tensor") 2735 2736 output_arrays = { 2737 descr.id: array for descr, array in zip(self.outputs, test_output_arrays) 2738 } 2739 for rep_tol in self.config.bioimageio.reproducibility_tolerance: 2740 if not rep_tol.absolute_tolerance: 2741 continue 2742 2743 if rep_tol.output_ids: 2744 out_arrays = { 2745 oid: a 2746 for oid, a in output_arrays.items() 2747 if oid in rep_tol.output_ids 2748 } 2749 else: 2750 out_arrays = output_arrays 2751 2752 for out_id, array in out_arrays.items(): 2753 if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01: 2754 raise ValueError( 2755 "config.bioimageio.reproducibility_tolerance.absolute_tolerance=" 2756 + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}" 2757 + f" (1% of the maximum value of the test tensor '{out_id}')" 2758 ) 2759 2760 return self 2761 2762 @model_validator(mode="after") 2763 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self: 2764 ipt_refs = {t.id for t in self.inputs} 2765 out_refs = {t.id for t in self.outputs} 2766 for ipt in self.inputs: 2767 for p in ipt.preprocessing: 2768 ref = p.kwargs.get("reference_tensor") 2769 if ref is None: 2770 continue 2771 if ref not in ipt_refs: 2772 raise ValueError( 2773 f"`reference_tensor` '{ref}' not found. Valid input tensor" 2774 + f" references are: {ipt_refs}." 2775 ) 2776 2777 for out in self.outputs: 2778 for p in out.postprocessing: 2779 ref = p.kwargs.get("reference_tensor") 2780 if ref is None: 2781 continue 2782 2783 if ref not in ipt_refs and ref not in out_refs: 2784 raise ValueError( 2785 f"`reference_tensor` '{ref}' not found. Valid tensor references" 2786 + f" are: {ipt_refs | out_refs}." 2787 ) 2788 2789 return self 2790 2791 # TODO: use validate funcs in validate_test_tensors 2792 # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]: 2793 2794 name: Annotated[ 2795 Annotated[ 2796 str, RestrictCharacters(string.ascii_letters + string.digits + "_+- ()") 2797 ], 2798 MinLen(5), 2799 MaxLen(128), 2800 warn(MaxLen(64), "Name longer than 64 characters.", INFO), 2801 ] 2802 """A human-readable name of this model. 2803 It should be no longer than 64 characters 2804 and may only contain letter, number, underscore, minus, parentheses and spaces. 2805 We recommend to chose a name that refers to the model's task and image modality. 
2806 """ 2807 2808 outputs: NotEmpty[Sequence[OutputTensorDescr]] 2809 """Describes the output tensors.""" 2810 2811 @field_validator("outputs", mode="after") 2812 @classmethod 2813 def _validate_tensor_ids( 2814 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo 2815 ) -> Sequence[OutputTensorDescr]: 2816 tensor_ids = [ 2817 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", []) 2818 ] 2819 duplicate_tensor_ids: List[str] = [] 2820 seen: Set[str] = set() 2821 for t in tensor_ids: 2822 if t in seen: 2823 duplicate_tensor_ids.append(t) 2824 2825 seen.add(t) 2826 2827 if duplicate_tensor_ids: 2828 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}") 2829 2830 return outputs 2831 2832 @staticmethod 2833 def _get_axes_with_parameterized_size( 2834 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2835 ): 2836 return { 2837 f"{t.id}.{a.id}": (t, a, a.size) 2838 for t in io 2839 for a in t.axes 2840 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize) 2841 } 2842 2843 @staticmethod 2844 def _get_axes_with_independent_size( 2845 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2846 ): 2847 return { 2848 (t.id, a.id): (t, a, a.size) 2849 for t in io 2850 for a in t.axes 2851 if not isinstance(a, BatchAxis) 2852 and isinstance(a.size, (int, ParameterizedSize)) 2853 } 2854 2855 @field_validator("outputs", mode="after") 2856 @classmethod 2857 def _validate_output_axes( 2858 cls, outputs: List[OutputTensorDescr], info: ValidationInfo 2859 ) -> List[OutputTensorDescr]: 2860 input_size_refs = cls._get_axes_with_independent_size( 2861 info.data.get("inputs", []) 2862 ) 2863 output_size_refs = cls._get_axes_with_independent_size(outputs) 2864 2865 for i, out in enumerate(outputs): 2866 valid_independent_refs: Dict[ 2867 Tuple[TensorId, AxisId], 2868 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2869 ] = { 2870 **{ 2871 (out.id, a.id): (out, a, a.size) 2872 for a in out.axes 2873 if not isinstance(a, BatchAxis) 2874 and isinstance(a.size, (int, ParameterizedSize)) 2875 }, 2876 **input_size_refs, 2877 **output_size_refs, 2878 } 2879 for a, ax in enumerate(out.axes): 2880 cls._validate_axis( 2881 "outputs", 2882 i, 2883 out.id, 2884 a, 2885 ax, 2886 valid_independent_refs=valid_independent_refs, 2887 ) 2888 2889 return outputs 2890 2891 packaged_by: List[Author] = Field( 2892 default_factory=cast(Callable[[], List[Author]], list) 2893 ) 2894 """The persons that have packaged and uploaded this model. 2895 Only required if those persons differ from the `authors`.""" 2896 2897 parent: Optional[LinkedModel] = None 2898 """The model from which this model is derived, e.g. by fine-tuning the weights.""" 2899 2900 @model_validator(mode="after") 2901 def _validate_parent_is_not_self(self) -> Self: 2902 if self.parent is not None and self.parent.id == self.id: 2903 raise ValueError("A model description may not reference itself as parent.") 2904 2905 return self 2906 2907 run_mode: Annotated[ 2908 Optional[RunMode], 2909 warn(None, "Run mode '{value}' has limited support across consumer softwares."), 2910 ] = None 2911 """Custom run mode for this model: for more complex prediction procedures like test time 2912 data augmentation that currently cannot be expressed in the specification. 
2913 No standard run modes are defined yet.""" 2914 2915 timestamp: Datetime = Field(default_factory=Datetime.now) 2916 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format 2917 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat). 2918 (In Python a datetime object is valid, too).""" 2919 2920 training_data: Annotated[ 2921 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02], 2922 Field(union_mode="left_to_right"), 2923 ] = None 2924 """The dataset used to train this model""" 2925 2926 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)] 2927 """The weights for this model. 2928 Weights can be given for different formats, but should otherwise be equivalent. 2929 The available weight formats determine which consumers can use this model.""" 2930 2931 config: Config = Field(default_factory=Config) 2932 2933 @model_validator(mode="after") 2934 def _add_default_cover(self) -> Self: 2935 if not get_validation_context().perform_io_checks or self.covers: 2936 return self 2937 2938 try: 2939 generated_covers = generate_covers( 2940 [(t, load_array(t.test_tensor)) for t in self.inputs], 2941 [(t, load_array(t.test_tensor)) for t in self.outputs], 2942 ) 2943 except Exception as e: 2944 issue_warning( 2945 "Failed to generate cover image(s): {e}", 2946 value=self.covers, 2947 msg_context=dict(e=e), 2948 field="covers", 2949 ) 2950 else: 2951 self.covers.extend(generated_covers) 2952 2953 return self 2954 2955 def get_input_test_arrays(self) -> List[NDArray[Any]]: 2956 data = [load_array(ipt.test_tensor) for ipt in self.inputs] 2957 assert all(isinstance(d, np.ndarray) for d in data) 2958 return data 2959 2960 def get_output_test_arrays(self) -> List[NDArray[Any]]: 2961 data = [load_array(out.test_tensor) for out in self.outputs] 2962 assert all(isinstance(d, np.ndarray) for d in data) 2963 return data 2964 2965 @staticmethod 2966 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 2967 batch_size = 1 2968 tensor_with_batchsize: Optional[TensorId] = None 2969 for tid in tensor_sizes: 2970 for aid, s in tensor_sizes[tid].items(): 2971 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 2972 continue 2973 2974 if batch_size != 1: 2975 assert tensor_with_batchsize is not None 2976 raise ValueError( 2977 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 2978 ) 2979 2980 batch_size = s 2981 tensor_with_batchsize = tid 2982 2983 return batch_size 2984 2985 def get_output_tensor_sizes( 2986 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 2987 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: 2988 """Returns the tensor output sizes for given **input_sizes**. 2989 Only if **input_sizes** has a valid input shape, the tensor output size is exact. 
2990 Otherwise it might be larger than the actual (valid) output""" 2991 batch_size = self.get_batch_size(input_sizes) 2992 ns = self.get_ns(input_sizes) 2993 2994 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 2995 return tensor_sizes.outputs 2996 2997 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 2998 """get parameter `n` for each parameterized axis 2999 such that the valid input size is >= the given input size""" 3000 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 3001 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 3002 for tid in input_sizes: 3003 for aid, s in input_sizes[tid].items(): 3004 size_descr = axes[tid][aid].size 3005 if isinstance(size_descr, ParameterizedSize): 3006 ret[(tid, aid)] = size_descr.get_n(s) 3007 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 3008 pass 3009 else: 3010 assert_never(size_descr) 3011 3012 return ret 3013 3014 def get_tensor_sizes( 3015 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 3016 ) -> _TensorSizes: 3017 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 3018 return _TensorSizes( 3019 { 3020 t: { 3021 aa: axis_sizes.inputs[(tt, aa)] 3022 for tt, aa in axis_sizes.inputs 3023 if tt == t 3024 } 3025 for t in {tt for tt, _ in axis_sizes.inputs} 3026 }, 3027 { 3028 t: { 3029 aa: axis_sizes.outputs[(tt, aa)] 3030 for tt, aa in axis_sizes.outputs 3031 if tt == t 3032 } 3033 for t in {tt for tt, _ in axis_sizes.outputs} 3034 }, 3035 ) 3036 3037 def get_axis_sizes( 3038 self, 3039 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], 3040 batch_size: Optional[int] = None, 3041 *, 3042 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None, 3043 ) -> _AxisSizes: 3044 """Determine input and output block shape for scale factors **ns** 3045 of parameterized input sizes. 3046 3047 Args: 3048 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) 3049 that is parameterized as `size = min + n * step`. 3050 batch_size: The desired size of the batch dimension. 3051 If given **batch_size** overwrites any batch size present in 3052 **max_input_shape**. Default 1. 3053 max_input_shape: Limits the derived block shapes. 3054 Each axis for which the input size, parameterized by `n`, is larger 3055 than **max_input_shape** is set to the minimal value `n_min` for which 3056 this is still true. 3057 Use this for small input samples or large values of **ns**. 3058 Or simply whenever you know the full input shape. 3059 3060 Returns: 3061 Resolved axis sizes for model inputs and outputs. 
3062 """ 3063 max_input_shape = max_input_shape or {} 3064 if batch_size is None: 3065 for (_t_id, a_id), s in max_input_shape.items(): 3066 if a_id == BATCH_AXIS_ID: 3067 batch_size = s 3068 break 3069 else: 3070 batch_size = 1 3071 3072 all_axes = { 3073 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs) 3074 } 3075 3076 inputs: Dict[Tuple[TensorId, AxisId], int] = {} 3077 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {} 3078 3079 def get_axis_size(a: Union[InputAxis, OutputAxis]): 3080 if isinstance(a, BatchAxis): 3081 if (t_descr.id, a.id) in ns: 3082 logger.warning( 3083 "Ignoring unexpected size increment factor (n) for batch axis" 3084 + " of tensor '{}'.", 3085 t_descr.id, 3086 ) 3087 return batch_size 3088 elif isinstance(a.size, int): 3089 if (t_descr.id, a.id) in ns: 3090 logger.warning( 3091 "Ignoring unexpected size increment factor (n) for fixed size" 3092 + " axis '{}' of tensor '{}'.", 3093 a.id, 3094 t_descr.id, 3095 ) 3096 return a.size 3097 elif isinstance(a.size, ParameterizedSize): 3098 if (t_descr.id, a.id) not in ns: 3099 raise ValueError( 3100 "Size increment factor (n) missing for parametrized axis" 3101 + f" '{a.id}' of tensor '{t_descr.id}'." 3102 ) 3103 n = ns[(t_descr.id, a.id)] 3104 s_max = max_input_shape.get((t_descr.id, a.id)) 3105 if s_max is not None: 3106 n = min(n, a.size.get_n(s_max)) 3107 3108 return a.size.get_size(n) 3109 3110 elif isinstance(a.size, SizeReference): 3111 if (t_descr.id, a.id) in ns: 3112 logger.warning( 3113 "Ignoring unexpected size increment factor (n) for axis '{}'" 3114 + " of tensor '{}' with size reference.", 3115 a.id, 3116 t_descr.id, 3117 ) 3118 assert not isinstance(a, BatchAxis) 3119 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id] 3120 assert not isinstance(ref_axis, BatchAxis) 3121 ref_key = (a.size.tensor_id, a.size.axis_id) 3122 ref_size = inputs.get(ref_key, outputs.get(ref_key)) 3123 assert ref_size is not None, ref_key 3124 assert not isinstance(ref_size, _DataDepSize), ref_key 3125 return a.size.get_size( 3126 axis=a, 3127 ref_axis=ref_axis, 3128 ref_size=ref_size, 3129 ) 3130 elif isinstance(a.size, DataDependentSize): 3131 if (t_descr.id, a.id) in ns: 3132 logger.warning( 3133 "Ignoring unexpected increment factor (n) for data dependent" 3134 + " size axis '{}' of tensor '{}'.", 3135 a.id, 3136 t_descr.id, 3137 ) 3138 return _DataDepSize(a.size.min, a.size.max) 3139 else: 3140 assert_never(a.size) 3141 3142 # first resolve all , but the `SizeReference` input sizes 3143 for t_descr in self.inputs: 3144 for a in t_descr.axes: 3145 if not isinstance(a.size, SizeReference): 3146 s = get_axis_size(a) 3147 assert not isinstance(s, _DataDepSize) 3148 inputs[t_descr.id, a.id] = s 3149 3150 # resolve all other input axis sizes 3151 for t_descr in self.inputs: 3152 for a in t_descr.axes: 3153 if isinstance(a.size, SizeReference): 3154 s = get_axis_size(a) 3155 assert not isinstance(s, _DataDepSize) 3156 inputs[t_descr.id, a.id] = s 3157 3158 # resolve all output axis sizes 3159 for t_descr in self.outputs: 3160 for a in t_descr.axes: 3161 assert not isinstance(a.size, ParameterizedSize) 3162 s = get_axis_size(a) 3163 outputs[t_descr.id, a.id] = s 3164 3165 return _AxisSizes(inputs=inputs, outputs=outputs) 3166 3167 @model_validator(mode="before") 3168 @classmethod 3169 def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]: 3170 cls.convert_from_old_format_wo_validation(data) 3171 return data 3172 3173 @classmethod 3174 def convert_from_old_format_wo_validation(cls, 
data: Dict[str, Any]) -> None: 3175 """Convert metadata following an older format version to this classes' format 3176 without validating the result. 3177 """ 3178 if ( 3179 data.get("type") == "model" 3180 and isinstance(fv := data.get("format_version"), str) 3181 and fv.count(".") == 2 3182 ): 3183 fv_parts = fv.split(".") 3184 if any(not p.isdigit() for p in fv_parts): 3185 return 3186 3187 fv_tuple = tuple(map(int, fv_parts)) 3188 3189 assert cls.implemented_format_version_tuple[0:2] == (0, 5) 3190 if fv_tuple[:2] in ((0, 3), (0, 4)): 3191 m04 = _ModelDescr_v0_4.load(data) 3192 if isinstance(m04, InvalidDescr): 3193 try: 3194 updated = _model_conv.convert_as_dict( 3195 m04 # pyright: ignore[reportArgumentType] 3196 ) 3197 except Exception as e: 3198 logger.error( 3199 "Failed to convert from invalid model 0.4 description." 3200 + f"\nerror: {e}" 3201 + "\nProceeding with model 0.5 validation without conversion." 3202 ) 3203 updated = None 3204 else: 3205 updated = _model_conv.convert_as_dict(m04) 3206 3207 if updated is not None: 3208 data.clear() 3209 data.update(updated) 3210 3211 elif fv_tuple[:2] == (0, 5): 3212 # bump patch version 3213 data["format_version"] = cls.implemented_format_version 3214 3215 3216class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]): 3217 def _convert( 3218 self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]" 3219 ) -> "ModelDescr | dict[str, Any]": 3220 name = "".join( 3221 c if c in string.ascii_letters + string.digits + "_+- ()" else " " 3222 for c in src.name 3223 ) 3224 3225 def conv_authors(auths: Optional[Sequence[_Author_v0_4]]): 3226 conv = ( 3227 _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict 3228 ) 3229 return None if auths is None else [conv(a) for a in auths] 3230 3231 if TYPE_CHECKING: 3232 arch_file_conv = _arch_file_conv.convert 3233 arch_lib_conv = _arch_lib_conv.convert 3234 else: 3235 arch_file_conv = _arch_file_conv.convert_as_dict 3236 arch_lib_conv = _arch_lib_conv.convert_as_dict 3237 3238 input_size_refs = { 3239 ipt.name: { 3240 a: s 3241 for a, s in zip( 3242 ipt.axes, 3243 ( 3244 ipt.shape.min 3245 if isinstance(ipt.shape, _ParameterizedInputShape_v0_4) 3246 else ipt.shape 3247 ), 3248 ) 3249 } 3250 for ipt in src.inputs 3251 if ipt.shape 3252 } 3253 output_size_refs = { 3254 **{ 3255 out.name: {a: s for a, s in zip(out.axes, out.shape)} 3256 for out in src.outputs 3257 if not isinstance(out.shape, _ImplicitOutputShape_v0_4) 3258 }, 3259 **input_size_refs, 3260 } 3261 3262 return tgt( 3263 attachments=( 3264 [] 3265 if src.attachments is None 3266 else [FileDescr(source=f) for f in src.attachments.files] 3267 ), 3268 authors=[ 3269 _author_conv.convert_as_dict(a) for a in src.authors 3270 ], # pyright: ignore[reportArgumentType] 3271 cite=[ 3272 {"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite 3273 ], # pyright: ignore[reportArgumentType] 3274 config=src.config, # pyright: ignore[reportArgumentType] 3275 covers=src.covers, 3276 description=src.description, 3277 documentation=src.documentation, 3278 format_version="0.5.4", 3279 git_repo=src.git_repo, # pyright: ignore[reportArgumentType] 3280 icon=src.icon, 3281 id=None if src.id is None else ModelId(src.id), 3282 id_emoji=src.id_emoji, 3283 license=src.license, # type: ignore 3284 links=src.links, 3285 maintainers=[ 3286 _maintainer_conv.convert_as_dict(m) for m in src.maintainers 3287 ], # pyright: ignore[reportArgumentType] 3288 name=name, 3289 tags=src.tags, 3290 type=src.type, 3291 uploader=src.uploader, 3292 
version=src.version, 3293 inputs=[ # pyright: ignore[reportArgumentType] 3294 _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs) 3295 for ipt, tt, st, in zip( 3296 src.inputs, 3297 src.test_inputs, 3298 src.sample_inputs or [None] * len(src.test_inputs), 3299 ) 3300 ], 3301 outputs=[ # pyright: ignore[reportArgumentType] 3302 _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs) 3303 for out, tt, st, in zip( 3304 src.outputs, 3305 src.test_outputs, 3306 src.sample_outputs or [None] * len(src.test_outputs), 3307 ) 3308 ], 3309 parent=( 3310 None 3311 if src.parent is None 3312 else LinkedModel( 3313 id=ModelId( 3314 str(src.parent.id) 3315 + ( 3316 "" 3317 if src.parent.version_number is None 3318 else f"/{src.parent.version_number}" 3319 ) 3320 ) 3321 ) 3322 ), 3323 training_data=( 3324 None 3325 if src.training_data is None 3326 else ( 3327 LinkedDataset( 3328 id=DatasetId( 3329 str(src.training_data.id) 3330 + ( 3331 "" 3332 if src.training_data.version_number is None 3333 else f"/{src.training_data.version_number}" 3334 ) 3335 ) 3336 ) 3337 if isinstance(src.training_data, LinkedDataset02) 3338 else src.training_data 3339 ) 3340 ), 3341 packaged_by=[ 3342 _author_conv.convert_as_dict(a) for a in src.packaged_by 3343 ], # pyright: ignore[reportArgumentType] 3344 run_mode=src.run_mode, 3345 timestamp=src.timestamp, 3346 weights=(WeightsDescr if TYPE_CHECKING else dict)( 3347 keras_hdf5=(w := src.weights.keras_hdf5) 3348 and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)( 3349 authors=conv_authors(w.authors), 3350 source=w.source, 3351 tensorflow_version=w.tensorflow_version or Version("1.15"), 3352 parent=w.parent, 3353 ), 3354 onnx=(w := src.weights.onnx) 3355 and (OnnxWeightsDescr if TYPE_CHECKING else dict)( 3356 source=w.source, 3357 authors=conv_authors(w.authors), 3358 parent=w.parent, 3359 opset_version=w.opset_version or 15, 3360 ), 3361 pytorch_state_dict=(w := src.weights.pytorch_state_dict) 3362 and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)( 3363 source=w.source, 3364 authors=conv_authors(w.authors), 3365 parent=w.parent, 3366 architecture=( 3367 arch_file_conv( 3368 w.architecture, 3369 w.architecture_sha256, 3370 w.kwargs, 3371 ) 3372 if isinstance(w.architecture, _CallableFromFile_v0_4) 3373 else arch_lib_conv(w.architecture, w.kwargs) 3374 ), 3375 pytorch_version=w.pytorch_version or Version("1.10"), 3376 dependencies=( 3377 None 3378 if w.dependencies is None 3379 else (FileDescr if TYPE_CHECKING else dict)( 3380 source=cast( 3381 FileSource, 3382 str(deps := w.dependencies)[ 3383 ( 3384 len("conda:") 3385 if str(deps).startswith("conda:") 3386 else 0 3387 ) : 3388 ], 3389 ) 3390 ) 3391 ), 3392 ), 3393 tensorflow_js=(w := src.weights.tensorflow_js) 3394 and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)( 3395 source=w.source, 3396 authors=conv_authors(w.authors), 3397 parent=w.parent, 3398 tensorflow_version=w.tensorflow_version or Version("1.15"), 3399 ), 3400 tensorflow_saved_model_bundle=( 3401 w := src.weights.tensorflow_saved_model_bundle 3402 ) 3403 and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)( 3404 authors=conv_authors(w.authors), 3405 parent=w.parent, 3406 source=w.source, 3407 tensorflow_version=w.tensorflow_version or Version("1.15"), 3408 dependencies=( 3409 None 3410 if w.dependencies is None 3411 else (FileDescr if TYPE_CHECKING else dict)( 3412 source=cast( 3413 FileSource, 3414 ( 3415 str(w.dependencies)[len("conda:") :] 3416 if str(w.dependencies).startswith("conda:") 3417 else 
str(w.dependencies) 3418 ), 3419 ) 3420 ) 3421 ), 3422 ), 3423 torchscript=(w := src.weights.torchscript) 3424 and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)( 3425 source=w.source, 3426 authors=conv_authors(w.authors), 3427 parent=w.parent, 3428 pytorch_version=w.pytorch_version or Version("1.10"), 3429 ), 3430 ), 3431 ) 3432 3433 3434_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr) 3435 3436 3437# create better cover images for 3d data and non-image outputs 3438def generate_covers( 3439 inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]], 3440 outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]], 3441) -> List[Path]: 3442 def squeeze( 3443 data: NDArray[Any], axes: Sequence[AnyAxis] 3444 ) -> Tuple[NDArray[Any], List[AnyAxis]]: 3445 """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining""" 3446 if data.ndim != len(axes): 3447 raise ValueError( 3448 f"tensor shape {data.shape} does not match described axes" 3449 + f" {[a.id for a in axes]}" 3450 ) 3451 3452 axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1] 3453 return data.squeeze(), axes 3454 3455 def normalize( 3456 data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7 3457 ) -> NDArray[np.float32]: 3458 data = data.astype("float32") 3459 data -= data.min(axis=axis, keepdims=True) 3460 data /= data.max(axis=axis, keepdims=True) + eps 3461 return data 3462 3463 def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]): 3464 original_shape = data.shape 3465 data, axes = squeeze(data, axes) 3466 3467 # take slice fom any batch or index axis if needed 3468 # and convert the first channel axis and take a slice from any additional channel axes 3469 slices: Tuple[slice, ...] = () 3470 ndim = data.ndim 3471 ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2 3472 has_c_axis = False 3473 for i, a in enumerate(axes): 3474 s = data.shape[i] 3475 assert s > 1 3476 if ( 3477 isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis)) 3478 and ndim > ndim_need 3479 ): 3480 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3481 ndim -= 1 3482 elif isinstance(a, ChannelAxis): 3483 if has_c_axis: 3484 # second channel axis 3485 data = data[slices + (slice(0, 1),)] 3486 ndim -= 1 3487 else: 3488 has_c_axis = True 3489 if s == 2: 3490 # visualize two channels with cyan and magenta 3491 data = np.concatenate( 3492 [ 3493 data[slices + (slice(1, 2),)], 3494 data[slices + (slice(0, 1),)], 3495 ( 3496 data[slices + (slice(0, 1),)] 3497 + data[slices + (slice(1, 2),)] 3498 ) 3499 / 2, # TODO: take maximum instead? 
3500 ], 3501 axis=i, 3502 ) 3503 elif data.shape[i] == 3: 3504 pass # visualize 3 channels as RGB 3505 else: 3506 # visualize first 3 channels as RGB 3507 data = data[slices + (slice(3),)] 3508 3509 assert data.shape[i] == 3 3510 3511 slices += (slice(None),) 3512 3513 data, axes = squeeze(data, axes) 3514 assert len(axes) == ndim 3515 # take slice from z axis if needed 3516 slices = () 3517 if ndim > ndim_need: 3518 for i, a in enumerate(axes): 3519 s = data.shape[i] 3520 if a.id == AxisId("z"): 3521 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3522 data, axes = squeeze(data, axes) 3523 ndim -= 1 3524 break 3525 3526 slices += (slice(None),) 3527 3528 # take slice from any space or time axis 3529 slices = () 3530 3531 for i, a in enumerate(axes): 3532 if ndim <= ndim_need: 3533 break 3534 3535 s = data.shape[i] 3536 assert s > 1 3537 if isinstance( 3538 a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis) 3539 ): 3540 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3541 ndim -= 1 3542 3543 slices += (slice(None),) 3544 3545 del slices 3546 data, axes = squeeze(data, axes) 3547 assert len(axes) == ndim 3548 3549 if (has_c_axis and ndim != 3) or ndim != 2: 3550 raise ValueError( 3551 f"Failed to construct cover image from shape {original_shape}" 3552 ) 3553 3554 if not has_c_axis: 3555 assert ndim == 2 3556 data = np.repeat(data[:, :, None], 3, axis=2) 3557 axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB")))) 3558 ndim += 1 3559 3560 assert ndim == 3 3561 3562 # transpose axis order such that longest axis comes first... 3563 axis_order: List[int] = list(np.argsort(list(data.shape))) 3564 axis_order.reverse() 3565 # ... and channel axis is last 3566 c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0] 3567 axis_order.append(axis_order.pop(c)) 3568 axes = [axes[ao] for ao in axis_order] 3569 data = data.transpose(axis_order) 3570 3571 # h, w = data.shape[:2] 3572 # if h / w in (1.0 or 2.0): 3573 # pass 3574 # elif h / w < 2: 3575 # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images 3576 3577 norm_along = ( 3578 tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None 3579 ) 3580 # normalize the data and map to 8 bit 3581 data = normalize(data, norm_along) 3582 data = (data * 255).astype("uint8") 3583 3584 return data 3585 3586 def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]): 3587 assert im0.dtype == im1.dtype == np.uint8 3588 assert im0.shape == im1.shape 3589 assert im0.ndim == 3 3590 N, M, C = im0.shape 3591 assert C == 3 3592 out = np.ones((N, M, C), dtype="uint8") 3593 for c in range(C): 3594 outc = np.tril(im0[..., c]) 3595 mask = outc == 0 3596 outc[mask] = np.triu(im1[..., c])[mask] 3597 out[..., c] = outc 3598 3599 return out 3600 3601 ipt_descr, ipt = inputs[0] 3602 out_descr, out = outputs[0] 3603 3604 ipt_img = to_2d_image(ipt, ipt_descr.axes) 3605 out_img = to_2d_image(out, out_descr.axes) 3606 3607 cover_folder = Path(mkdtemp()) 3608 if ipt_img.shape == out_img.shape: 3609 covers = [cover_folder / "cover.png"] 3610 imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img)) 3611 else: 3612 covers = [cover_folder / "input.png", cover_folder / "output.png"] 3613 imwrite(covers[0], ipt_img) 3614 imwrite(covers[1], out_img) 3615 3616 return covers
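As a usage illustration of the size helpers defined on ModelDescr above (get_ns, get_tensor_sizes, get_axis_sizes), here is a minimal, hedged sketch. It assumes the top-level bioimageio.spec.load_description helper and uses a placeholder model source; any local rdf.yaml path or bioimage.io model id would do.

from bioimageio.spec import load_description  # assumed top-level helper
from bioimageio.spec.model.v0_5 import ModelDescr, ParameterizedSize

# placeholder source: a local rdf.yaml path or a bioimage.io model id
model = load_description("path/to/rdf.yaml")
assert isinstance(model, ModelDescr)

# request the minimal size (n=0) for every parameterized input axis
ns = {
    (t.id, a.id): 0
    for t in model.inputs
    for a in t.axes
    if isinstance(getattr(a, "size", None), ParameterizedSize)
}
sizes = model.get_tensor_sizes(ns, batch_size=1)
print(sizes.inputs)   # {tensor_id: {axis_id: int}}
print(sizes.outputs)  # ints, or (min, max) bounds for data-dependent output axes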
Space unit compatible with the OME-Zarr axes specification 0.5
Time unit compatible with the OME-Zarr axes specification 0.5
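These unit literals populate the unit field of space and time axes. A minimal sketch (axis classes and unit names are taken from this module; the exact keyword combinations are an assumption):

from bioimageio.spec.model.v0_5 import AxisId, SpaceInputAxis, TimeInputAxis

# a 2D spatial input sampled at 0.5 micrometer per pixel
x = SpaceInputAxis(id=AxisId("x"), size=512, unit="micrometer", scale=0.5)
y = SpaceInputAxis(id=AxisId("y"), size=512, unit="micrometer", scale=0.5)

# a time axis sampled once per second
t = TimeInputAxis(size=100, unit="second", scale=1.0)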
class TensorId(LowerCaseIdentifier):
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
    ]
class AxisId(LowerCaseIdentifier):
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            LowerCaseIdentifierAnno,
            MaxLen(16),
            AfterValidator(_normalize_axis_id),
        ]
    ]
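Both id types are validated lower-case identifiers: TensorId allows at most 32 characters, AxisId at most 16 and is additionally normalized via _normalize_axis_id. A minimal sketch:

from bioimageio.spec.model.v0_5 import AxisId, TensorId

# tensor ids: lower-case identifiers, at most 32 characters
raw = TensorId("raw")

# axis ids: lower-case identifiers, at most 16 characters (normalized on creation)
x = AxisId("x")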
Annotates an integer that is used to calculate a concrete axis size from a ParameterizedSize.
class ParameterizedSize(Node):
    """Describes a range of valid tensor axis sizes as `size = min + n*step`.

    - **min** and **step** are given by the model description.
    - All block size parameters n = 0, 1, 2, ... yield a valid `size`.
    - A greater block size parameter n results in a greater **size**,
      which allows adjusting the axis size generically.
    """

    N: ClassVar[Type[int]] = ParameterizedSize_N
    """Positive integer to parameterize this axis"""

    min: Annotated[int, Gt(0)]
    step: Annotated[int, Gt(0)]

    def validate_size(self, size: int) -> int:
        if size < self.min:
            raise ValueError(f"size {size} < {self.min}")
        if (size - self.min) % self.step != 0:
            raise ValueError(
                f"axis of size {size} is not parameterized by `min + n*step` ="
                + f" `{self.min} + n*{self.step}`"
            )

        return size

    def get_size(self, n: ParameterizedSize_N) -> int:
        return self.min + self.step * n

    def get_n(self, s: int) -> ParameterizedSize_N:
        """Return the smallest n parameterizing a size greater than or equal to `s`."""
        return ceil((s - self.min) / self.step)
Describes a range of valid tensor axis sizes as `size = min + n*step`.
- `min` and `step` are given by the model description.
- All block size parameters n = 0, 1, 2, ... yield a valid `size`.
- A greater block size parameter n results in a greater `size`. This allows the axis size to be adjusted generically.
310 def validate_size(self, size: int) -> int: 311 if size < self.min: 312 raise ValueError(f"size {size} < {self.min}") 313 if (size - self.min) % self.step != 0: 314 raise ValueError( 315 f"axis of size {size} is not parameterized by `min + n*step` =" 316 + f" `{self.min} + n*{self.step}`" 317 ) 318 319 return size
324 def get_n(self, s: int) -> ParameterizedSize_N: 325 """return smallest n parameterizing a size greater or equal than `s`""" 326 return ceil((s - self.min) / self.step)
Return the smallest n parameterizing a size greater than or equal to `s`.
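A minimal usage sketch (the concrete min and step values are just example choices):

>>> from bioimageio.spec.model.v0_5 import ParameterizedSize
>>> p = ParameterizedSize(min=32, step=16)
>>> p.get_size(n=2)  # 32 + 2*16
64
>>> p.get_n(40)  # smallest n such that min + n*step >= 40
1
>>> p.get_size(p.get_n(40))
48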
329class DataDependentSize(Node): 330 min: Annotated[int, Gt(0)] = 1 331 max: Annotated[Optional[int], Gt(1)] = None 332 333 @model_validator(mode="after") 334 def _validate_max_gt_min(self): 335 if self.max is not None and self.min >= self.max: 336 raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}") 337 338 return self 339 340 def validate_size(self, size: int) -> int: 341 if size < self.min: 342 raise ValueError(f"size {size} < {self.min}") 343 344 if self.max is not None and size > self.max: 345 raise ValueError(f"size {size} > {self.max}") 346 347 return size
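A minimal sketch of the corresponding range check, assuming an example upper bound of 512:

>>> from bioimageio.spec.model.v0_5 import DataDependentSize
>>> DataDependentSize(min=1, max=512).validate_size(100)
100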
350class SizeReference(Node): 351 """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis. 352 353 `axis.size = reference.size * reference.scale / axis.scale + offset` 354 355 Note: 356 1. The axis and the referenced axis need to have the same unit (or no unit). 357 2. Batch axes may not be referenced. 358 3. Fractions are rounded down. 359 4. If the reference axis is `concatenable` the referencing axis is assumed to be 360 `concatenable` as well with the same block order. 361 362 Example: 363 An unisotropic input image of w*h=100*49 pixels depicts a phsical space of 200*196mm². 364 Let's assume that we want to express the image height h in relation to its width w 365 instead of only accepting input images of exactly 100*49 pixels 366 (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`). 367 368 >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2) 369 >>> h = SpaceInputAxis( 370 ... id=AxisId("h"), 371 ... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1), 372 ... unit="millimeter", 373 ... scale=4, 374 ... ) 375 >>> print(h.size.get_size(h, w)) 376 49 377 378 ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49 379 """ 380 381 tensor_id: TensorId 382 """tensor id of the reference axis""" 383 384 axis_id: AxisId 385 """axis id of the reference axis""" 386 387 offset: StrictInt = 0 388 389 def get_size( 390 self, 391 axis: Union[ 392 ChannelAxis, 393 IndexInputAxis, 394 IndexOutputAxis, 395 TimeInputAxis, 396 SpaceInputAxis, 397 TimeOutputAxis, 398 TimeOutputAxisWithHalo, 399 SpaceOutputAxis, 400 SpaceOutputAxisWithHalo, 401 ], 402 ref_axis: Union[ 403 ChannelAxis, 404 IndexInputAxis, 405 IndexOutputAxis, 406 TimeInputAxis, 407 SpaceInputAxis, 408 TimeOutputAxis, 409 TimeOutputAxisWithHalo, 410 SpaceOutputAxis, 411 SpaceOutputAxisWithHalo, 412 ], 413 n: ParameterizedSize_N = 0, 414 ref_size: Optional[int] = None, 415 ): 416 """Compute the concrete size for a given axis and its reference axis. 417 418 Args: 419 axis: The axis this `SizeReference` is the size of. 420 ref_axis: The reference axis to compute the size from. 421 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`) 422 and no fixed **ref_size** is given, 423 **n** is used to compute the size of the parameterized **ref_axis**. 424 ref_size: Overwrite the reference size instead of deriving it from 425 **ref_axis** 426 (**ref_axis.scale** is still used; any given **n** is ignored). 427 """ 428 assert ( 429 axis.size == self 430 ), "Given `axis.size` is not defined by this `SizeReference`" 431 432 assert ( 433 ref_axis.id == self.axis_id 434 ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}." 435 436 assert axis.unit == ref_axis.unit, ( 437 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`," 438 f" but {axis.unit}!={ref_axis.unit}" 439 ) 440 if ref_size is None: 441 if isinstance(ref_axis.size, (int, float)): 442 ref_size = ref_axis.size 443 elif isinstance(ref_axis.size, ParameterizedSize): 444 ref_size = ref_axis.size.get_size(n) 445 elif isinstance(ref_axis.size, DataDependentSize): 446 raise ValueError( 447 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`." 448 ) 449 elif isinstance(ref_axis.size, SizeReference): 450 raise ValueError( 451 "Reference axis referenced in `SizeReference` may not be sized by a" 452 + " `SizeReference` itself." 
453 ) 454 else: 455 assert_never(ref_axis.size) 456 457 return int(ref_size * ref_axis.scale / axis.scale + self.offset) 458 459 @staticmethod 460 def _get_unit( 461 axis: Union[ 462 ChannelAxis, 463 IndexInputAxis, 464 IndexOutputAxis, 465 TimeInputAxis, 466 SpaceInputAxis, 467 TimeOutputAxis, 468 TimeOutputAxisWithHalo, 469 SpaceOutputAxis, 470 SpaceOutputAxisWithHalo, 471 ], 472 ): 473 return axis.unit
A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.
axis.size = reference.size * reference.scale / axis.scale + offset
Note:
- The axis and the referenced axis need to have the same unit (or no unit).
- Batch axes may not be referenced.
- Fractions are rounded down.
- If the reference axis is `concatenable` the referencing axis is assumed to be `concatenable` as well with the same block order.
Example:
An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196mm².
Let's assume that we want to express the image height h in relation to its width w
instead of only accepting input images of exactly 100*49 pixels
(for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).
>>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
>>> h = SpaceInputAxis(
... id=AxisId("h"),
... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
... unit="millimeter",
... scale=4,
... )
>>> print(h.size.get_size(h, w))
49
⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
389 def get_size( 390 self, 391 axis: Union[ 392 ChannelAxis, 393 IndexInputAxis, 394 IndexOutputAxis, 395 TimeInputAxis, 396 SpaceInputAxis, 397 TimeOutputAxis, 398 TimeOutputAxisWithHalo, 399 SpaceOutputAxis, 400 SpaceOutputAxisWithHalo, 401 ], 402 ref_axis: Union[ 403 ChannelAxis, 404 IndexInputAxis, 405 IndexOutputAxis, 406 TimeInputAxis, 407 SpaceInputAxis, 408 TimeOutputAxis, 409 TimeOutputAxisWithHalo, 410 SpaceOutputAxis, 411 SpaceOutputAxisWithHalo, 412 ], 413 n: ParameterizedSize_N = 0, 414 ref_size: Optional[int] = None, 415 ): 416 """Compute the concrete size for a given axis and its reference axis. 417 418 Args: 419 axis: The axis this `SizeReference` is the size of. 420 ref_axis: The reference axis to compute the size from. 421 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`) 422 and no fixed **ref_size** is given, 423 **n** is used to compute the size of the parameterized **ref_axis**. 424 ref_size: Overwrite the reference size instead of deriving it from 425 **ref_axis** 426 (**ref_axis.scale** is still used; any given **n** is ignored). 427 """ 428 assert ( 429 axis.size == self 430 ), "Given `axis.size` is not defined by this `SizeReference`" 431 432 assert ( 433 ref_axis.id == self.axis_id 434 ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}." 435 436 assert axis.unit == ref_axis.unit, ( 437 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`," 438 f" but {axis.unit}!={ref_axis.unit}" 439 ) 440 if ref_size is None: 441 if isinstance(ref_axis.size, (int, float)): 442 ref_size = ref_axis.size 443 elif isinstance(ref_axis.size, ParameterizedSize): 444 ref_size = ref_axis.size.get_size(n) 445 elif isinstance(ref_axis.size, DataDependentSize): 446 raise ValueError( 447 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`." 448 ) 449 elif isinstance(ref_axis.size, SizeReference): 450 raise ValueError( 451 "Reference axis referenced in `SizeReference` may not be sized by a" 452 + " `SizeReference` itself." 453 ) 454 else: 455 assert_never(ref_axis.size) 456 457 return int(ref_size * ref_axis.scale / axis.scale + self.offset)
Compute the concrete size for a given axis and its reference axis.
Arguments:
- axis: The axis this `SizeReference` is the size of.
- ref_axis: The reference axis to compute the size from.
- n: If the ref_axis is parameterized (of type `ParameterizedSize`) and no fixed ref_size is given, n is used to compute the size of the parameterized ref_axis.
- ref_size: Overwrite the reference size instead of deriving it from ref_axis (ref_axis.scale is still used; any given n is ignored).
476class AxisBase(NodeWithExplicitlySetFields): 477 id: AxisId 478 """An axis id unique across all axes of one tensor.""" 479 480 description: Annotated[str, MaxLen(128)] = ""
483class WithHalo(Node): 484 halo: Annotated[int, Ge(1)] 485 """The halo should be cropped from the output tensor to avoid boundary effects. 486 It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`. 487 To document a halo that is already cropped by the model use `size.offset` instead.""" 488 489 size: Annotated[ 490 SizeReference, 491 Field( 492 examples=[ 493 10, 494 SizeReference( 495 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 496 ).model_dump(mode="json"), 497 ] 498 ), 499 ] 500 """reference to another axis with an optional offset (see `SizeReference`)"""
The halo should be cropped from the output tensor to avoid boundary effects.
It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
To document a halo that is already cropped by the model use `size.offset` instead.
reference to another axis with an optional offset (see `SizeReference`)
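A minimal sketch of an output axis with a halo; the referenced tensor id 'input' and axis id 'x' are hypothetical names chosen for illustration:

>>> from bioimageio.spec.model.v0_5 import (
...     AxisId, SizeReference, SpaceOutputAxisWithHalo, TensorId
... )
>>> out_x = SpaceOutputAxisWithHalo(
...     id=AxisId("x"),
...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("x")),
...     halo=32,
... )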
506class BatchAxis(AxisBase): 507 implemented_type: ClassVar[Literal["batch"]] = "batch" 508 if TYPE_CHECKING: 509 type: Literal["batch"] = "batch" 510 else: 511 type: Literal["batch"] 512 513 id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID 514 size: Optional[Literal[1]] = None 515 """The batch size may be fixed to 1, 516 otherwise (the default) it may be chosen arbitrarily depending on available memory""" 517 518 @property 519 def scale(self): 520 return 1.0 521 522 @property 523 def concatenable(self): 524 return True 525 526 @property 527 def unit(self): 528 return None
The batch size may be fixed to 1, otherwise (the default) it may be chosen arbitrarily depending on available memory
531class ChannelAxis(AxisBase): 532 implemented_type: ClassVar[Literal["channel"]] = "channel" 533 if TYPE_CHECKING: 534 type: Literal["channel"] = "channel" 535 else: 536 type: Literal["channel"] 537 538 id: NonBatchAxisId = AxisId("channel") 539 channel_names: NotEmpty[List[Identifier]] 540 541 @property 542 def size(self) -> int: 543 return len(self.channel_names) 544 545 @property 546 def concatenable(self): 547 return False 548 549 @property 550 def scale(self) -> float: 551 return 1.0 552 553 @property 554 def unit(self): 555 return None
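A minimal sketch; the size of a channel axis is implied by its channel names:

>>> from bioimageio.spec.model.v0_5 import ChannelAxis, Identifier
>>> rgb = ChannelAxis(
...     channel_names=[Identifier("red"), Identifier("green"), Identifier("blue")]
... )
>>> rgb.size
3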
558class IndexAxisBase(AxisBase): 559 implemented_type: ClassVar[Literal["index"]] = "index" 560 if TYPE_CHECKING: 561 type: Literal["index"] = "index" 562 else: 563 type: Literal["index"] 564 565 id: NonBatchAxisId = AxisId("index") 566 567 @property 568 def scale(self) -> float: 569 return 1.0 570 571 @property 572 def unit(self): 573 return None
596class IndexInputAxis(IndexAxisBase, _WithInputAxisSize): 597 concatenable: bool = False 598 """If a model has a `concatenable` input axis, it can be processed blockwise, 599 splitting a longer sample axis into blocks matching its input tensor description. 600 Output axes are concatenable if they have a `SizeReference` to a concatenable 601 input axis. 602 """
If a model has a `concatenable` input axis, it can be processed blockwise,
splitting a longer sample axis into blocks matching its input tensor description.
Output axes are concatenable if they have a `SizeReference` to a `concatenable` input axis.
605class IndexOutputAxis(IndexAxisBase): 606 size: Annotated[ 607 Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize], 608 Field( 609 examples=[ 610 10, 611 SizeReference( 612 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 613 ).model_dump(mode="json"), 614 ] 615 ), 616 ] 617 """The size/length of this axis can be specified as 618 - fixed integer 619 - reference to another axis with an optional offset (`SizeReference`) 620 - data dependent size using `DataDependentSize` (size is only known after model inference) 621 """
The size/length of this axis can be specified as
- fixed integer
- reference to another axis with an optional offset (`SizeReference`)
- data dependent size using `DataDependentSize` (size is only known after model inference)
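A minimal sketch of a data dependent output axis, e.g. for a variable number of detected objects:

>>> from bioimageio.spec.model.v0_5 import DataDependentSize, IndexOutputAxis
>>> objects = IndexOutputAxis(size=DataDependentSize(min=1))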
624class TimeAxisBase(AxisBase): 625 implemented_type: ClassVar[Literal["time"]] = "time" 626 if TYPE_CHECKING: 627 type: Literal["time"] = "time" 628 else: 629 type: Literal["time"] 630 631 id: NonBatchAxisId = AxisId("time") 632 unit: Optional[TimeUnit] = None 633 scale: Annotated[float, Gt(0)] = 1.0
636class TimeInputAxis(TimeAxisBase, _WithInputAxisSize): 637 concatenable: bool = False 638 """If a model has a `concatenable` input axis, it can be processed blockwise, 639 splitting a longer sample axis into blocks matching its input tensor description. 640 Output axes are concatenable if they have a `SizeReference` to a concatenable 641 input axis. 642 """
If a model has a `concatenable` input axis, it can be processed blockwise,
splitting a longer sample axis into blocks matching its input tensor description.
Output axes are concatenable if they have a `SizeReference` to a `concatenable` input axis.
645class SpaceAxisBase(AxisBase): 646 implemented_type: ClassVar[Literal["space"]] = "space" 647 if TYPE_CHECKING: 648 type: Literal["space"] = "space" 649 else: 650 type: Literal["space"] 651 652 id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x") 653 unit: Optional[SpaceUnit] = None 654 scale: Annotated[float, Gt(0)] = 1.0
An axis id unique across all axes of one tensor.
657class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize): 658 concatenable: bool = False 659 """If a model has a `concatenable` input axis, it can be processed blockwise, 660 splitting a longer sample axis into blocks matching its input tensor description. 661 Output axes are concatenable if they have a `SizeReference` to a concatenable 662 input axis. 663 """
If a model has a `concatenable` input axis, it can be processed blockwise,
splitting a longer sample axis into blocks matching its input tensor description.
Output axes are concatenable if they have a `SizeReference` to a `concatenable` input axis.
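Putting the input axis types together, a minimal sketch of a 2D, single-channel input layout with parameterized spatial sizes (axis ids, channel name and size values are example choices):

>>> from bioimageio.spec.model.v0_5 import (
...     AxisId, BatchAxis, ChannelAxis, Identifier, ParameterizedSize, SpaceInputAxis
... )
>>> input_axes = [
...     BatchAxis(),
...     ChannelAxis(channel_names=[Identifier("raw")]),
...     SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=64, step=16)),
...     SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=64, step=16)),
... ]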
785class NominalOrOrdinalDataDescr(Node): 786 values: TVs 787 """A fixed set of nominal or an ascending sequence of ordinal values. 788 In this case `data.type` is required to be an unsigend integer type, e.g. 'uint8'. 789 String `values` are interpreted as labels for tensor values 0, ..., N. 790 Note: as YAML 1.2 does not natively support a "set" datatype, 791 nominal values should be given as a sequence (aka list/array) as well. 792 """ 793 794 type: Annotated[ 795 NominalOrOrdinalDType, 796 Field( 797 examples=[ 798 "float32", 799 "uint8", 800 "uint16", 801 "int64", 802 "bool", 803 ], 804 ), 805 ] = "uint8" 806 807 @model_validator(mode="after") 808 def _validate_values_match_type( 809 self, 810 ) -> Self: 811 incompatible: List[Any] = [] 812 for v in self.values: 813 if self.type == "bool": 814 if not isinstance(v, bool): 815 incompatible.append(v) 816 elif self.type in DTYPE_LIMITS: 817 if ( 818 isinstance(v, (int, float)) 819 and ( 820 v < DTYPE_LIMITS[self.type].min 821 or v > DTYPE_LIMITS[self.type].max 822 ) 823 or (isinstance(v, str) and "uint" not in self.type) 824 or (isinstance(v, float) and "int" in self.type) 825 ): 826 incompatible.append(v) 827 else: 828 incompatible.append(v) 829 830 if len(incompatible) == 5: 831 incompatible.append("...") 832 break 833 834 if incompatible: 835 raise ValueError( 836 f"data type '{self.type}' incompatible with values {incompatible}" 837 ) 838 839 return self 840 841 unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None 842 843 @property 844 def range(self): 845 if isinstance(self.values[0], str): 846 return 0, len(self.values) - 1 847 else: 848 return min(self.values), max(self.values)
A fixed set of nominal or an ascending sequence of ordinal values.
In this case `data.type` is required to be an unsigned integer type, e.g. 'uint8'.
String `values` are interpreted as labels for tensor values 0, ..., N.
Note: as YAML 1.2 does not natively support a "set" datatype,
nominal values should be given as a sequence (aka list/array) as well.
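A minimal sketch where string values act as labels for the tensor values 0 and 1:

>>> from bioimageio.spec.model.v0_5 import NominalOrOrdinalDataDescr
>>> labels = NominalOrOrdinalDataDescr(values=["background", "foreground"], type="uint8")
>>> labels.range
(0, 1)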
865class IntervalOrRatioDataDescr(Node): 866 type: Annotated[ # todo: rename to dtype 867 IntervalOrRatioDType, 868 Field( 869 examples=["float32", "float64", "uint8", "uint16"], 870 ), 871 ] = "float32" 872 range: Tuple[Optional[float], Optional[float]] = ( 873 None, 874 None, 875 ) 876 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor. 877 `None` corresponds to min/max of what can be expressed by **type**.""" 878 unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit" 879 scale: float = 1.0 880 """Scale for data on an interval (or ratio) scale.""" 881 offset: Optional[float] = None 882 """Offset for data on a ratio scale.""" 883 884 @model_validator(mode="before") 885 def _replace_inf(cls, data: Any): 886 if is_dict(data): 887 if "range" in data and is_sequence(data["range"]): 888 forbidden = ( 889 "inf", 890 "-inf", 891 ".inf", 892 "-.inf", 893 float("inf"), 894 float("-inf"), 895 ) 896 if any(v in forbidden for v in data["range"]): 897 issue_warning("replaced 'inf' value", value=data["range"]) 898 899 data["range"] = tuple( 900 (None if v in forbidden else v) for v in data["range"] 901 ) 902 903 return data
Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
`None` corresponds to min/max of what can be expressed by type.
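A minimal sketch describing float32 data restricted to the unit interval:

>>> from bioimageio.spec.model.v0_5 import IntervalOrRatioDataDescr
>>> data_descr = IntervalOrRatioDataDescr(type="float32", range=(0.0, 1.0))
>>> data_descr.range
(0.0, 1.0)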
processing base class
913class BinarizeKwargs(ProcessingKwargs): 914 """key word arguments for `BinarizeDescr`""" 915 916 threshold: float 917 """The fixed threshold"""
keyword arguments for `BinarizeDescr`
920class BinarizeAlongAxisKwargs(ProcessingKwargs): 921 """key word arguments for `BinarizeDescr`""" 922 923 threshold: NotEmpty[List[float]] 924 """The fixed threshold values along `axis`""" 925 926 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 927 """The `threshold` axis"""
keyword arguments for `BinarizeDescr`
The `threshold` axis
930class BinarizeDescr(ProcessingDescrBase): 931 """Binarize the tensor with a fixed threshold. 932 933 Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold` 934 will be set to one, values below the threshold to zero. 935 936 Examples: 937 - in YAML 938 ```yaml 939 postprocessing: 940 - id: binarize 941 kwargs: 942 axis: 'channel' 943 threshold: [0.25, 0.5, 0.75] 944 ``` 945 - in Python: 946 >>> postprocessing = [BinarizeDescr( 947 ... kwargs=BinarizeAlongAxisKwargs( 948 ... axis=AxisId('channel'), 949 ... threshold=[0.25, 0.5, 0.75], 950 ... ) 951 ... )] 952 """ 953 954 implemented_id: ClassVar[Literal["binarize"]] = "binarize" 955 if TYPE_CHECKING: 956 id: Literal["binarize"] = "binarize" 957 else: 958 id: Literal["binarize"] 959 kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]
Binarize the tensor with a fixed threshold.
Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold` will be set to one, values below the threshold to zero.
Examples:
- in YAML
postprocessing:
- id: binarize
  kwargs:
    axis: 'channel'
    threshold: [0.25, 0.5, 0.75]
- in Python:
>>> postprocessing = [BinarizeDescr(
...     kwargs=BinarizeAlongAxisKwargs(
...         axis=AxisId('channel'),
...         threshold=[0.25, 0.5, 0.75],
...     )
... )]
962class ClipDescr(ProcessingDescrBase): 963 """Set tensor values below min to min and above max to max. 964 965 See `ScaleRangeDescr` for examples. 966 """ 967 968 implemented_id: ClassVar[Literal["clip"]] = "clip" 969 if TYPE_CHECKING: 970 id: Literal["clip"] = "clip" 971 else: 972 id: Literal["clip"] 973 974 kwargs: ClipKwargs
Set tensor values below min to min and above max to max.
See `ScaleRangeDescr` for examples.
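A minimal sketch of a clip step on its own (the min/max values are example choices):

>>> from bioimageio.spec.model.v0_5 import ClipDescr, ClipKwargs
>>> postprocessing = [ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0))]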
977class EnsureDtypeKwargs(ProcessingKwargs): 978 """key word arguments for `EnsureDtypeDescr`""" 979 980 dtype: Literal[ 981 "float32", 982 "float64", 983 "uint8", 984 "int8", 985 "uint16", 986 "int16", 987 "uint32", 988 "int32", 989 "uint64", 990 "int64", 991 "bool", 992 ]
keyword arguments for `EnsureDtypeDescr`
995class EnsureDtypeDescr(ProcessingDescrBase): 996 """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching). 997 998 This can for example be used to ensure the inner neural network model gets a 999 different input tensor data type than the fully described bioimage.io model does. 1000 1001 Examples: 1002 The described bioimage.io model (incl. preprocessing) accepts any 1003 float32-compatible tensor, normalizes it with percentiles and clipping and then 1004 casts it to uint8, which is what the neural network in this example expects. 1005 - in YAML 1006 ```yaml 1007 inputs: 1008 - data: 1009 type: float32 # described bioimage.io model is compatible with any float32 input tensor 1010 preprocessing: 1011 - id: scale_range 1012 kwargs: 1013 axes: ['y', 'x'] 1014 max_percentile: 99.8 1015 min_percentile: 5.0 1016 - id: clip 1017 kwargs: 1018 min: 0.0 1019 max: 1.0 1020 - id: ensure_dtype # the neural network of the model requires uint8 1021 kwargs: 1022 dtype: uint8 1023 ``` 1024 - in Python: 1025 >>> preprocessing = [ 1026 ... ScaleRangeDescr( 1027 ... kwargs=ScaleRangeKwargs( 1028 ... axes= (AxisId('y'), AxisId('x')), 1029 ... max_percentile= 99.8, 1030 ... min_percentile= 5.0, 1031 ... ) 1032 ... ), 1033 ... ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)), 1034 ... EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")), 1035 ... ] 1036 """ 1037 1038 implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype" 1039 if TYPE_CHECKING: 1040 id: Literal["ensure_dtype"] = "ensure_dtype" 1041 else: 1042 id: Literal["ensure_dtype"] 1043 1044 kwargs: EnsureDtypeKwargs
Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).
This can for example be used to ensure the inner neural network model gets a different input tensor data type than the fully described bioimage.io model does.
Examples:
The described bioimage.io model (incl. preprocessing) accepts any float32-compatible tensor, normalizes it with percentiles and clipping and then casts it to uint8, which is what the neural network in this example expects.
- in YAML
inputs:
- data:
    type: float32  # described bioimage.io model is compatible with any float32 input tensor
  preprocessing:
  - id: scale_range
    kwargs:
      axes: ['y', 'x']
      max_percentile: 99.8
      min_percentile: 5.0
  - id: clip
    kwargs:
      min: 0.0
      max: 1.0
  - id: ensure_dtype  # the neural network of the model requires uint8
    kwargs:
      dtype: uint8
- in Python:
>>> preprocessing = [
...     ScaleRangeDescr(
...         kwargs=ScaleRangeKwargs(
...             axes= (AxisId('y'), AxisId('x')),
...             max_percentile= 99.8,
...             min_percentile= 5.0,
...         )
...     ),
...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
... ]
1047class ScaleLinearKwargs(ProcessingKwargs): 1048 """Key word arguments for `ScaleLinearDescr`""" 1049 1050 gain: float = 1.0 1051 """multiplicative factor""" 1052 1053 offset: float = 0.0 1054 """additive term""" 1055 1056 @model_validator(mode="after") 1057 def _validate(self) -> Self: 1058 if self.gain == 1.0 and self.offset == 0.0: 1059 raise ValueError( 1060 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 1061 + " != 0.0." 1062 ) 1063 1064 return self
Keyword arguments for `ScaleLinearDescr`
1067class ScaleLinearAlongAxisKwargs(ProcessingKwargs): 1068 """Key word arguments for `ScaleLinearDescr`""" 1069 1070 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 1071 """The axis of gain and offset values.""" 1072 1073 gain: Union[float, NotEmpty[List[float]]] = 1.0 1074 """multiplicative factor""" 1075 1076 offset: Union[float, NotEmpty[List[float]]] = 0.0 1077 """additive term""" 1078 1079 @model_validator(mode="after") 1080 def _validate(self) -> Self: 1081 1082 if isinstance(self.gain, list): 1083 if isinstance(self.offset, list): 1084 if len(self.gain) != len(self.offset): 1085 raise ValueError( 1086 f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match." 1087 ) 1088 else: 1089 self.offset = [float(self.offset)] * len(self.gain) 1090 elif isinstance(self.offset, list): 1091 self.gain = [float(self.gain)] * len(self.offset) 1092 else: 1093 raise ValueError( 1094 "Do not specify an `axis` for scalar gain and offset values." 1095 ) 1096 1097 if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset): 1098 raise ValueError( 1099 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 1100 + " != 0.0." 1101 ) 1102 1103 return self
Keyword arguments for `ScaleLinearDescr`
The axis of gain and offset values.
1106class ScaleLinearDescr(ProcessingDescrBase): 1107 """Fixed linear scaling. 1108 1109 Examples: 1110 1. Scale with scalar gain and offset 1111 - in YAML 1112 ```yaml 1113 preprocessing: 1114 - id: scale_linear 1115 kwargs: 1116 gain: 2.0 1117 offset: 3.0 1118 ``` 1119 - in Python: 1120 >>> preprocessing = [ 1121 ... ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0)) 1122 ... ] 1123 1124 2. Independent scaling along an axis 1125 - in YAML 1126 ```yaml 1127 preprocessing: 1128 - id: scale_linear 1129 kwargs: 1130 axis: 'channel' 1131 gain: [1.0, 2.0, 3.0] 1132 ``` 1133 - in Python: 1134 >>> preprocessing = [ 1135 ... ScaleLinearDescr( 1136 ... kwargs=ScaleLinearAlongAxisKwargs( 1137 ... axis=AxisId("channel"), 1138 ... gain=[1.0, 2.0, 3.0], 1139 ... ) 1140 ... ) 1141 ... ] 1142 1143 """ 1144 1145 implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear" 1146 if TYPE_CHECKING: 1147 id: Literal["scale_linear"] = "scale_linear" 1148 else: 1149 id: Literal["scale_linear"] 1150 kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]
Fixed linear scaling.
Examples:
Scale with scalar gain and offset
in YAML
preprocessing:
- id: scale_linear
  kwargs:
    gain: 2.0
    offset: 3.0
in Python:
>>> preprocessing = [
...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
... ]
Independent scaling along an axis
in YAML
preprocessing:
- id: scale_linear
  kwargs:
    axis: 'channel'
    gain: [1.0, 2.0, 3.0]
in Python:
>>> preprocessing = [
...     ScaleLinearDescr(
...         kwargs=ScaleLinearAlongAxisKwargs(
...             axis=AxisId("channel"),
...             gain=[1.0, 2.0, 3.0],
...         )
...     )
... ]
1153class SigmoidDescr(ProcessingDescrBase): 1154 """The logistic sigmoid funciton, a.k.a. expit function. 1155 1156 Examples: 1157 - in YAML 1158 ```yaml 1159 postprocessing: 1160 - id: sigmoid 1161 ``` 1162 - in Python: 1163 >>> postprocessing = [SigmoidDescr()] 1164 """ 1165 1166 implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid" 1167 if TYPE_CHECKING: 1168 id: Literal["sigmoid"] = "sigmoid" 1169 else: 1170 id: Literal["sigmoid"] 1171 1172 @property 1173 def kwargs(self) -> ProcessingKwargs: 1174 """empty kwargs""" 1175 return ProcessingKwargs()
The logistic sigmoid function, a.k.a. expit function.
Examples:
- in YAML
postprocessing:
- id: sigmoid
- in Python:
>>> postprocessing = [SigmoidDescr()]
1172 @property 1173 def kwargs(self) -> ProcessingKwargs: 1174 """empty kwargs""" 1175 return ProcessingKwargs()
empty kwargs
1178class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs): 1179 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 1180 1181 mean: float 1182 """The mean value to normalize with.""" 1183 1184 std: Annotated[float, Ge(1e-6)] 1185 """The standard deviation value to normalize with."""
keyword arguments for `FixedZeroMeanUnitVarianceDescr`
1188class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs): 1189 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 1190 1191 mean: NotEmpty[List[float]] 1192 """The mean value(s) to normalize with.""" 1193 1194 std: NotEmpty[List[Annotated[float, Ge(1e-6)]]] 1195 """The standard deviation value(s) to normalize with. 1196 Size must match `mean` values.""" 1197 1198 axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])] 1199 """The axis of the mean/std values to normalize each entry along that dimension 1200 separately.""" 1201 1202 @model_validator(mode="after") 1203 def _mean_and_std_match(self) -> Self: 1204 if len(self.mean) != len(self.std): 1205 raise ValueError( 1206 f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})" 1207 + " must match." 1208 ) 1209 1210 return self
keyword arguments for `FixedZeroMeanUnitVarianceDescr`
The standard deviation value(s) to normalize with.
Size must match `mean` values.
The axis of the mean/std values to normalize each entry along that dimension separately.
1213class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase): 1214 """Subtract a given mean and divide by the standard deviation. 1215 1216 Normalize with fixed, precomputed values for 1217 `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std` 1218 Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given 1219 axes. 1220 1221 Examples: 1222 1. scalar value for whole tensor 1223 - in YAML 1224 ```yaml 1225 preprocessing: 1226 - id: fixed_zero_mean_unit_variance 1227 kwargs: 1228 mean: 103.5 1229 std: 13.7 1230 ``` 1231 - in Python 1232 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 1233 ... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7) 1234 ... )] 1235 1236 2. independently along an axis 1237 - in YAML 1238 ```yaml 1239 preprocessing: 1240 - id: fixed_zero_mean_unit_variance 1241 kwargs: 1242 axis: channel 1243 mean: [101.5, 102.5, 103.5] 1244 std: [11.7, 12.7, 13.7] 1245 ``` 1246 - in Python 1247 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 1248 ... kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 1249 ... axis=AxisId("channel"), 1250 ... mean=[101.5, 102.5, 103.5], 1251 ... std=[11.7, 12.7, 13.7], 1252 ... ) 1253 ... )] 1254 """ 1255 1256 implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = ( 1257 "fixed_zero_mean_unit_variance" 1258 ) 1259 if TYPE_CHECKING: 1260 id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance" 1261 else: 1262 id: Literal["fixed_zero_mean_unit_variance"] 1263 1264 kwargs: Union[ 1265 FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs 1266 ]
Subtract a given mean and divide by the standard deviation.
Normalize with fixed, precomputed values for `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`.
Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given axes.
Examples:
- scalar value for whole tensor
- in YAML
preprocessing:
- id: fixed_zero_mean_unit_variance
  kwargs:
    mean: 103.5
    std: 13.7
- in Python
>>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
... )]
- independently along an axis
- in YAML
preprocessing:
- id: fixed_zero_mean_unit_variance
  kwargs:
    axis: channel
    mean: [101.5, 102.5, 103.5]
    std: [11.7, 12.7, 13.7]
- in Python
>>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
...     kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
...         axis=AxisId("channel"),
...         mean=[101.5, 102.5, 103.5],
...         std=[11.7, 12.7, 13.7],
...     )
... )]
1269class ZeroMeanUnitVarianceKwargs(ProcessingKwargs): 1270 """key word arguments for `ZeroMeanUnitVarianceDescr`""" 1271 1272 axes: Annotated[ 1273 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1274 ] = None 1275 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 1276 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1277 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1278 To normalize each sample independently leave out the 'batch' axis. 1279 Default: Scale all axes jointly.""" 1280 1281 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1282 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""
keyword arguments for `ZeroMeanUnitVarianceDescr`
The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
To normalize each sample independently leave out the 'batch' axis.
Default: Scale all axes jointly.
epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`.
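A minimal sketch of the per-channel use case described above, normalizing over 'batch', 'x' and 'y' so that each channel is scaled independently:

>>> from bioimageio.spec.model.v0_5 import (
...     AxisId, ZeroMeanUnitVarianceDescr, ZeroMeanUnitVarianceKwargs
... )
>>> preprocessing = [
...     ZeroMeanUnitVarianceDescr(
...         kwargs=ZeroMeanUnitVarianceKwargs(
...             axes=(AxisId("batch"), AxisId("x"), AxisId("y"))
...         )
...     )
... ]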
1285class ZeroMeanUnitVarianceDescr(ProcessingDescrBase): 1286 """Subtract mean and divide by variance. 1287 1288 Examples: 1289 Subtract tensor mean and variance 1290 - in YAML 1291 ```yaml 1292 preprocessing: 1293 - id: zero_mean_unit_variance 1294 ``` 1295 - in Python 1296 >>> preprocessing = [ZeroMeanUnitVarianceDescr()] 1297 """ 1298 1299 implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = ( 1300 "zero_mean_unit_variance" 1301 ) 1302 if TYPE_CHECKING: 1303 id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance" 1304 else: 1305 id: Literal["zero_mean_unit_variance"] 1306 1307 kwargs: ZeroMeanUnitVarianceKwargs = Field( 1308 default_factory=ZeroMeanUnitVarianceKwargs 1309 )
Subtract the mean and divide by the standard deviation (zero-mean, unit-variance normalization).
Examples:
Subtract tensor mean and variance
- in YAML
preprocessing:
- id: zero_mean_unit_variance
- in Python
>>> preprocessing = [ZeroMeanUnitVarianceDescr()]
1312class ScaleRangeKwargs(ProcessingKwargs): 1313 """key word arguments for `ScaleRangeDescr` 1314 1315 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default) 1316 this processing step normalizes data to the [0, 1] intervall. 1317 For other percentiles the normalized values will partially be outside the [0, 1] 1318 intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the 1319 normalized values to a range. 1320 """ 1321 1322 axes: Annotated[ 1323 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1324 ] = None 1325 """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value. 1326 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1327 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1328 To normalize samples independently, leave out the "batch" axis. 1329 Default: Scale all axes jointly.""" 1330 1331 min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0 1332 """The lower percentile used to determine the value to align with zero.""" 1333 1334 max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0 1335 """The upper percentile used to determine the value to align with one. 1336 Has to be bigger than `min_percentile`. 1337 The range is 1 to 100 instead of 0 to 100 to avoid mistakenly 1338 accepting percentiles specified in the range 0.0 to 1.0.""" 1339 1340 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1341 """Epsilon for numeric stability. 1342 `out = (tensor - v_lower) / (v_upper - v_lower + eps)`; 1343 with `v_lower,v_upper` values at the respective percentiles.""" 1344 1345 reference_tensor: Optional[TensorId] = None 1346 """Tensor ID to compute the percentiles from. Default: The tensor itself. 1347 For any tensor in `inputs` only input tensor references are allowed.""" 1348 1349 @field_validator("max_percentile", mode="after") 1350 @classmethod 1351 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float: 1352 if (min_p := info.data["min_percentile"]) >= value: 1353 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}") 1354 1355 return value
keyword arguments for `ScaleRangeDescr`
For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
this processing step normalizes data to the [0, 1] interval.
For other percentiles the normalized values will partially be outside the [0, 1]
interval. Use `ScaleRange` followed by `ClipDescr` if you want to limit the
normalized values to a range.
The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
To normalize samples independently, leave out the "batch" axis.
Default: Scale all axes jointly.
The lower percentile used to determine the value to align with zero.
The upper percentile used to determine the value to align with one.
Has to be bigger than `min_percentile`.
The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
accepting percentiles specified in the range 0.0 to 1.0.
Epsilon for numeric stability.
`out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
with `v_lower,v_upper` values at the respective percentiles.
Tensor ID to compute the percentiles from. Default: The tensor itself.
For any tensor in `inputs` only input tensor references are allowed.
1349 @field_validator("max_percentile", mode="after") 1350 @classmethod 1351 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float: 1352 if (min_p := info.data["min_percentile"]) >= value: 1353 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}") 1354 1355 return value
1358class ScaleRangeDescr(ProcessingDescrBase): 1359 """Scale with percentiles. 1360 1361 Examples: 1362 1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0 1363 - in YAML 1364 ```yaml 1365 preprocessing: 1366 - id: scale_range 1367 kwargs: 1368 axes: ['y', 'x'] 1369 max_percentile: 99.8 1370 min_percentile: 5.0 1371 ``` 1372 - in Python 1373 >>> preprocessing = [ 1374 ... ScaleRangeDescr( 1375 ... kwargs=ScaleRangeKwargs( 1376 ... axes= (AxisId('y'), AxisId('x')), 1377 ... max_percentile= 99.8, 1378 ... min_percentile= 5.0, 1379 ... ) 1380 ... ), 1381 ... ClipDescr( 1382 ... kwargs=ClipKwargs( 1383 ... min=0.0, 1384 ... max=1.0, 1385 ... ) 1386 ... ), 1387 ... ] 1388 1389 2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles. 1390 - in YAML 1391 ```yaml 1392 preprocessing: 1393 - id: scale_range 1394 kwargs: 1395 axes: ['y', 'x'] 1396 max_percentile: 99.8 1397 min_percentile: 5.0 1398 - id: scale_range 1399 - id: clip 1400 kwargs: 1401 min: 0.0 1402 max: 1.0 1403 ``` 1404 - in Python 1405 >>> preprocessing = [ScaleRangeDescr( 1406 ... kwargs=ScaleRangeKwargs( 1407 ... axes= (AxisId('y'), AxisId('x')), 1408 ... max_percentile= 99.8, 1409 ... min_percentile= 5.0, 1410 ... ) 1411 ... )] 1412 1413 """ 1414 1415 implemented_id: ClassVar[Literal["scale_range"]] = "scale_range" 1416 if TYPE_CHECKING: 1417 id: Literal["scale_range"] = "scale_range" 1418 else: 1419 id: Literal["scale_range"] 1420 kwargs: ScaleRangeKwargs
Scale with percentiles.
Examples:
- Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
- in YAML
preprocessing:
- id: scale_range
  kwargs:
    axes: ['y', 'x']
    max_percentile: 99.8
    min_percentile: 5.0
- in Python
>>> preprocessing = [
... ScaleRangeDescr(
... kwargs=ScaleRangeKwargs(
... axes= (AxisId('y'), AxisId('x')),
... max_percentile= 99.8,
... min_percentile= 5.0,
... )
... ),
... ClipDescr(
... kwargs=ClipKwargs(
... min=0.0,
... max=1.0,
... )
... ),
... ]
- Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
- in YAML
preprocessing:
- id: scale_range
  kwargs:
    axes: ['y', 'x']
    max_percentile: 99.8
    min_percentile: 5.0
- id: clip
  kwargs:
    min: 0.0
    max: 1.0
- in Python
>>> preprocessing = [ScaleRangeDescr(
...     kwargs=ScaleRangeKwargs(
...         axes= (AxisId('y'), AxisId('x')),
...         max_percentile= 99.8,
...         min_percentile= 5.0,
...     )
... )]
1423class ScaleMeanVarianceKwargs(ProcessingKwargs): 1424 """key word arguments for `ScaleMeanVarianceKwargs`""" 1425 1426 reference_tensor: TensorId 1427 """Name of tensor to match.""" 1428 1429 axes: Annotated[ 1430 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1431 ] = None 1432 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 1433 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1434 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1435 To normalize samples independently, leave out the 'batch' axis. 1436 Default: Scale all axes jointly.""" 1437 1438 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1439 """Epsilon for numeric stability: 1440 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""
keyword arguments for `ScaleMeanVarianceDescr`
The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
To normalize samples independently, leave out the 'batch' axis.
Default: Scale all axes jointly.
Epsilon for numeric stability:
`out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean`.
1443class ScaleMeanVarianceDescr(ProcessingDescrBase): 1444 """Scale a tensor's data distribution to match another tensor's mean/std. 1445 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.` 1446 """ 1447 1448 implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance" 1449 if TYPE_CHECKING: 1450 id: Literal["scale_mean_variance"] = "scale_mean_variance" 1451 else: 1452 id: Literal["scale_mean_variance"] 1453 kwargs: ScaleMeanVarianceKwargs
Scale a tensor's data distribution to match another tensor's mean/std.
out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.
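A minimal sketch; the reference tensor id 'raw' is a hypothetical name standing in for whichever tensor's statistics should be matched:

>>> from bioimageio.spec.model.v0_5 import (
...     ScaleMeanVarianceDescr, ScaleMeanVarianceKwargs, TensorId
... )
>>> postprocessing = [
...     ScaleMeanVarianceDescr(
...         kwargs=ScaleMeanVarianceKwargs(reference_tensor=TensorId("raw"))
...     )
... ]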
1487class TensorDescrBase(Node, Generic[IO_AxisT]): 1488 id: TensorId 1489 """Tensor id. No duplicates are allowed.""" 1490 1491 description: Annotated[str, MaxLen(128)] = "" 1492 """free text description""" 1493 1494 axes: NotEmpty[Sequence[IO_AxisT]] 1495 """tensor axes""" 1496 1497 @property 1498 def shape(self): 1499 return tuple(a.size for a in self.axes) 1500 1501 @field_validator("axes", mode="after", check_fields=False) 1502 @classmethod 1503 def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]: 1504 batch_axes = [a for a in axes if a.type == "batch"] 1505 if len(batch_axes) > 1: 1506 raise ValueError( 1507 f"Only one batch axis (per tensor) allowed, but got {batch_axes}" 1508 ) 1509 1510 seen_ids: Set[AxisId] = set() 1511 duplicate_axes_ids: Set[AxisId] = set() 1512 for a in axes: 1513 (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id) 1514 1515 if duplicate_axes_ids: 1516 raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}") 1517 1518 return axes 1519 1520 test_tensor: FileDescr_ 1521 """An example tensor to use for testing. 1522 Using the model with the test input tensors is expected to yield the test output tensors. 1523 Each test tensor has be a an ndarray in the 1524 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format). 1525 The file extension must be '.npy'.""" 1526 1527 sample_tensor: Optional[FileDescr_] = None 1528 """A sample tensor to illustrate a possible input/output for the model, 1529 The sample image primarily serves to inform a human user about an example use case 1530 and is typically stored as .hdf5, .png or .tiff. 1531 It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats) 1532 (numpy's `.npy` format is not supported). 1533 The image dimensionality has to match the number of axes specified in this tensor description. 1534 """ 1535 1536 @model_validator(mode="after") 1537 def _validate_sample_tensor(self) -> Self: 1538 if self.sample_tensor is None or not get_validation_context().perform_io_checks: 1539 return self 1540 1541 reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256) 1542 tensor: NDArray[Any] = imread( 1543 reader.read(), 1544 extension=PurePosixPath(reader.original_file_name).suffix, 1545 ) 1546 n_dims = len(tensor.squeeze().shape) 1547 n_dims_min = n_dims_max = len(self.axes) 1548 1549 for a in self.axes: 1550 if isinstance(a, BatchAxis): 1551 n_dims_min -= 1 1552 elif isinstance(a.size, int): 1553 if a.size == 1: 1554 n_dims_min -= 1 1555 elif isinstance(a.size, (ParameterizedSize, DataDependentSize)): 1556 if a.size.min == 1: 1557 n_dims_min -= 1 1558 elif isinstance(a.size, SizeReference): 1559 if a.size.offset < 2: 1560 # size reference may result in singleton axis 1561 n_dims_min -= 1 1562 else: 1563 assert_never(a.size) 1564 1565 n_dims_min = max(0, n_dims_min) 1566 if n_dims < n_dims_min or n_dims > n_dims_max: 1567 raise ValueError( 1568 f"Expected sample tensor to have {n_dims_min} to" 1569 + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})." 1570 ) 1571 1572 return self 1573 1574 data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = ( 1575 IntervalOrRatioDataDescr() 1576 ) 1577 """Description of the tensor's data values, optionally per channel. 
1578 If specified per channel, the data `type` needs to match across channels.""" 1579 1580 @property 1581 def dtype( 1582 self, 1583 ) -> Literal[ 1584 "float32", 1585 "float64", 1586 "uint8", 1587 "int8", 1588 "uint16", 1589 "int16", 1590 "uint32", 1591 "int32", 1592 "uint64", 1593 "int64", 1594 "bool", 1595 ]: 1596 """dtype as specified under `data.type` or `data[i].type`""" 1597 if isinstance(self.data, collections.abc.Sequence): 1598 return self.data[0].type 1599 else: 1600 return self.data.type 1601 1602 @field_validator("data", mode="after") 1603 @classmethod 1604 def _check_data_type_across_channels( 1605 cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] 1606 ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]: 1607 if not isinstance(value, list): 1608 return value 1609 1610 dtypes = {t.type for t in value} 1611 if len(dtypes) > 1: 1612 raise ValueError( 1613 "Tensor data descriptions per channel need to agree in their data" 1614 + f" `type`, but found {dtypes}." 1615 ) 1616 1617 return value 1618 1619 @model_validator(mode="after") 1620 def _check_data_matches_channelaxis(self) -> Self: 1621 if not isinstance(self.data, (list, tuple)): 1622 return self 1623 1624 for a in self.axes: 1625 if isinstance(a, ChannelAxis): 1626 size = a.size 1627 assert isinstance(size, int) 1628 break 1629 else: 1630 return self 1631 1632 if len(self.data) != size: 1633 raise ValueError( 1634 f"Got tensor data descriptions for {len(self.data)} channels, but" 1635 + f" '{a.id}' axis has size {size}." 1636 ) 1637 1638 return self 1639 1640 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]: 1641 if len(array.shape) != len(self.axes): 1642 raise ValueError( 1643 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})" 1644 + f" incompatible with {len(self.axes)} axes." 1645 ) 1646 return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
An example tensor to use for testing. Using the model with the test input tensors is expected to yield the test output tensors. Each test tensor has to be an ndarray in the numpy.lib file format. The file extension must be '.npy'.
A sample tensor to illustrate a possible input/output for the model.
The sample image primarily serves to inform a human user about an example use case
and is typically stored as .hdf5, .png or .tiff.
It has to be readable by the imageio library (numpy's `.npy` format is not supported).
The image dimensionality has to match the number of axes specified in this tensor description.
Description of the tensor's data values, optionally per channel.
If specified per channel, the data `type` needs to match across channels.
1580 @property 1581 def dtype( 1582 self, 1583 ) -> Literal[ 1584 "float32", 1585 "float64", 1586 "uint8", 1587 "int8", 1588 "uint16", 1589 "int16", 1590 "uint32", 1591 "int32", 1592 "uint64", 1593 "int64", 1594 "bool", 1595 ]: 1596 """dtype as specified under `data.type` or `data[i].type`""" 1597 if isinstance(self.data, collections.abc.Sequence): 1598 return self.data[0].type 1599 else: 1600 return self.data.type
dtype as specified under `data.type` or `data[i].type`
1640 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]: 1641 if len(array.shape) != len(self.axes): 1642 raise ValueError( 1643 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})" 1644 + f" incompatible with {len(self.axes)} axes." 1645 ) 1646 return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
1649class InputTensorDescr(TensorDescrBase[InputAxis]): 1650 id: TensorId = TensorId("input") 1651 """Input tensor id. 1652 No duplicates are allowed across all inputs and outputs.""" 1653 1654 optional: bool = False 1655 """indicates that this tensor may be `None`""" 1656 1657 preprocessing: List[PreprocessingDescr] = Field( 1658 default_factory=cast(Callable[[], List[PreprocessingDescr]], list) 1659 ) 1660 1661 """Description of how this input should be preprocessed. 1662 1663 notes: 1664 - If preprocessing does not start with an 'ensure_dtype' entry, it is added 1665 to ensure an input tensor's data type matches the input tensor's data description. 1666 - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 1667 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally 1668 changing the data type. 1669 """ 1670 1671 @model_validator(mode="after") 1672 def _validate_preprocessing_kwargs(self) -> Self: 1673 axes_ids = [a.id for a in self.axes] 1674 for p in self.preprocessing: 1675 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 1676 if kwargs_axes is None: 1677 continue 1678 1679 if not isinstance(kwargs_axes, collections.abc.Sequence): 1680 raise ValueError( 1681 f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}" 1682 ) 1683 1684 if any(a not in axes_ids for a in kwargs_axes): 1685 raise ValueError( 1686 "`preprocessing.i.kwargs.axes` needs to be subset of axes ids" 1687 ) 1688 1689 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 1690 dtype = self.data.type 1691 else: 1692 dtype = self.data[0].type 1693 1694 # ensure `preprocessing` begins with `EnsureDtypeDescr` 1695 if not self.preprocessing or not isinstance( 1696 self.preprocessing[0], EnsureDtypeDescr 1697 ): 1698 self.preprocessing.insert( 1699 0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1700 ) 1701 1702 # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 1703 if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)): 1704 self.preprocessing.append( 1705 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1706 ) 1707 1708 return self
Description of how this input should be preprocessed.
notes:
- If preprocessing does not start with an 'ensure_dtype' entry, it is added to ensure an input tensor's data type matches the input tensor's data description.
- If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally changing the data type.
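As a minimal sketch of what these notes mean in practice: for an input tensor described with float32 data and the preprocessing below, validation prepends an 'ensure_dtype' step with dtype float32 and, because the chain does not end with 'ensure_dtype' or 'binarize', appends one as well.

>>> from bioimageio.spec.model.v0_5 import AxisId, ScaleRangeDescr, ScaleRangeKwargs
>>> preprocessing = [
...     ScaleRangeDescr(
...         kwargs=ScaleRangeKwargs(
...             axes=(AxisId("y"), AxisId("x")),
...             max_percentile=99.8,
...             min_percentile=5.0,
...         )
...     ),
... ]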
1711def convert_axes( 1712 axes: str, 1713 *, 1714 shape: Union[ 1715 Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4 1716 ], 1717 tensor_type: Literal["input", "output"], 1718 halo: Optional[Sequence[int]], 1719 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 1720): 1721 ret: List[AnyAxis] = [] 1722 for i, a in enumerate(axes): 1723 axis_type = _AXIS_TYPE_MAP.get(a, a) 1724 if axis_type == "batch": 1725 ret.append(BatchAxis()) 1726 continue 1727 1728 scale = 1.0 1729 if isinstance(shape, _ParameterizedInputShape_v0_4): 1730 if shape.step[i] == 0: 1731 size = shape.min[i] 1732 else: 1733 size = ParameterizedSize(min=shape.min[i], step=shape.step[i]) 1734 elif isinstance(shape, _ImplicitOutputShape_v0_4): 1735 ref_t = str(shape.reference_tensor) 1736 if ref_t.count(".") == 1: 1737 t_id, orig_a_id = ref_t.split(".") 1738 else: 1739 t_id = ref_t 1740 orig_a_id = a 1741 1742 a_id = _AXIS_ID_MAP.get(orig_a_id, a) 1743 if not (orig_scale := shape.scale[i]): 1744 # old way to insert a new axis dimension 1745 size = int(2 * shape.offset[i]) 1746 else: 1747 scale = 1 / orig_scale 1748 if axis_type in ("channel", "index"): 1749 # these axes no longer have a scale 1750 offset_from_scale = orig_scale * size_refs.get( 1751 _TensorName_v0_4(t_id), {} 1752 ).get(orig_a_id, 0) 1753 else: 1754 offset_from_scale = 0 1755 size = SizeReference( 1756 tensor_id=TensorId(t_id), 1757 axis_id=AxisId(a_id), 1758 offset=int(offset_from_scale + 2 * shape.offset[i]), 1759 ) 1760 else: 1761 size = shape[i] 1762 1763 if axis_type == "time": 1764 if tensor_type == "input": 1765 ret.append(TimeInputAxis(size=size, scale=scale)) 1766 else: 1767 assert not isinstance(size, ParameterizedSize) 1768 if halo is None: 1769 ret.append(TimeOutputAxis(size=size, scale=scale)) 1770 else: 1771 assert not isinstance(size, int) 1772 ret.append( 1773 TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i]) 1774 ) 1775 1776 elif axis_type == "index": 1777 if tensor_type == "input": 1778 ret.append(IndexInputAxis(size=size)) 1779 else: 1780 if isinstance(size, ParameterizedSize): 1781 size = DataDependentSize(min=size.min) 1782 1783 ret.append(IndexOutputAxis(size=size)) 1784 elif axis_type == "channel": 1785 assert not isinstance(size, ParameterizedSize) 1786 if isinstance(size, SizeReference): 1787 warnings.warn( 1788 "Conversion of channel size from an implicit output shape may be" 1789 + " wrong" 1790 ) 1791 ret.append( 1792 ChannelAxis( 1793 channel_names=[ 1794 Identifier(f"channel{i}") for i in range(size.offset) 1795 ] 1796 ) 1797 ) 1798 else: 1799 ret.append( 1800 ChannelAxis( 1801 channel_names=[Identifier(f"channel{i}") for i in range(size)] 1802 ) 1803 ) 1804 elif axis_type == "space": 1805 if tensor_type == "input": 1806 ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale)) 1807 else: 1808 assert not isinstance(size, ParameterizedSize) 1809 if halo is None or halo[i] == 0: 1810 ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale)) 1811 elif isinstance(size, int): 1812 raise NotImplementedError( 1813 f"output axis with halo and fixed size (here {size}) not allowed" 1814 ) 1815 else: 1816 ret.append( 1817 SpaceOutputAxisWithHalo( 1818 id=AxisId(a), size=size, scale=scale, halo=halo[i] 1819 ) 1820 ) 1821 1822 return ret
1982class OutputTensorDescr(TensorDescrBase[OutputAxis]): 1983 id: TensorId = TensorId("output") 1984 """Output tensor id. 1985 No duplicates are allowed across all inputs and outputs.""" 1986 1987 postprocessing: List[PostprocessingDescr] = Field( 1988 default_factory=cast(Callable[[], List[PostprocessingDescr]], list) 1989 ) 1990 """Description of how this output should be postprocessed. 1991 1992 note: `postprocessing` always ends with an 'ensure_dtype' operation. 1993 If not given this is added to cast to this tensor's `data.type`. 1994 """ 1995 1996 @model_validator(mode="after") 1997 def _validate_postprocessing_kwargs(self) -> Self: 1998 axes_ids = [a.id for a in self.axes] 1999 for p in self.postprocessing: 2000 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 2001 if kwargs_axes is None: 2002 continue 2003 2004 if not isinstance(kwargs_axes, collections.abc.Sequence): 2005 raise ValueError( 2006 f"expected `axes` sequence, but got {type(kwargs_axes)}" 2007 ) 2008 2009 if any(a not in axes_ids for a in kwargs_axes): 2010 raise ValueError("`kwargs.axes` needs to be subset of axes ids") 2011 2012 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 2013 dtype = self.data.type 2014 else: 2015 dtype = self.data[0].type 2016 2017 # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 2018 if not self.postprocessing or not isinstance( 2019 self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr) 2020 ): 2021 self.postprocessing.append( 2022 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 2023 ) 2024 return self
Description of how this output should be postprocessed.
Note: postprocessing always ends with an 'ensure_dtype' operation. If not given, this is added to cast to this tensor's data.type.
2074def validate_tensors( 2075 tensors: Mapping[TensorId, Tuple[TensorDescr, NDArray[Any]]], 2076 tensor_origin: Literal[ 2077 "test_tensor" 2078 ], # for more precise error messages, e.g. 'test_tensor' 2079): 2080 all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, int]]] = {} 2081 2082 def e_msg(d: TensorDescr): 2083 return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]" 2084 2085 for descr, array in tensors.values(): 2086 try: 2087 axis_sizes = descr.get_axis_sizes_for_array(array) 2088 except ValueError as e: 2089 raise ValueError(f"{e_msg(descr)} {e}") 2090 else: 2091 all_tensor_axes[descr.id] = { 2092 a.id: (a, axis_sizes[a.id]) for a in descr.axes 2093 } 2094 2095 for descr, array in tensors.values(): 2096 if descr.dtype in ("float32", "float64"): 2097 invalid_test_tensor_dtype = array.dtype.name not in ( 2098 "float32", 2099 "float64", 2100 "uint8", 2101 "int8", 2102 "uint16", 2103 "int16", 2104 "uint32", 2105 "int32", 2106 "uint64", 2107 "int64", 2108 ) 2109 else: 2110 invalid_test_tensor_dtype = array.dtype.name != descr.dtype 2111 2112 if invalid_test_tensor_dtype: 2113 raise ValueError( 2114 f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not" 2115 + f" match described dtype '{descr.dtype}'" 2116 ) 2117 2118 if array.min() > -1e-4 and array.max() < 1e-4: 2119 raise ValueError( 2120 "Output values are too small for reliable testing." 2121 + f" Values <-1e5 or >=1e5 must be present in {tensor_origin}" 2122 ) 2123 2124 for a in descr.axes: 2125 actual_size = all_tensor_axes[descr.id][a.id][1] 2126 if a.size is None: 2127 continue 2128 2129 if isinstance(a.size, int): 2130 if actual_size != a.size: 2131 raise ValueError( 2132 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' " 2133 + f"has incompatible size {actual_size}, expected {a.size}" 2134 ) 2135 elif isinstance(a.size, ParameterizedSize): 2136 _ = a.size.validate_size(actual_size) 2137 elif isinstance(a.size, DataDependentSize): 2138 _ = a.size.validate_size(actual_size) 2139 elif isinstance(a.size, SizeReference): 2140 ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id) 2141 if ref_tensor_axes is None: 2142 raise ValueError( 2143 f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor" 2144 + f" reference '{a.size.tensor_id}'" 2145 ) 2146 2147 ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None)) 2148 if ref_axis is None or ref_size is None: 2149 raise ValueError( 2150 f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis" 2151 + f" reference '{a.size.tensor_id}.{a.size.axis_id}" 2152 ) 2153 2154 if a.unit != ref_axis.unit: 2155 raise ValueError( 2156 f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires" 2157 + " axis and reference axis to have the same `unit`, but" 2158 + f" {a.unit}!={ref_axis.unit}" 2159 ) 2160 2161 if actual_size != ( 2162 expected_size := ( 2163 ref_size * ref_axis.scale / a.scale + a.size.offset 2164 ) 2165 ): 2166 raise ValueError( 2167 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size" 2168 + f" {actual_size} invalid for referenced size {ref_size};" 2169 + f" expected {expected_size}" 2170 ) 2171 else: 2172 assert_never(a.size)
2192class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr): 2193 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2194 """Architecture source file""" 2195 2196 @model_serializer(mode="wrap", when_used="unless-none") 2197 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 2198 return package_file_descr_serializer(self, nxt, info)
A file description
Architecture source file
2201class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr): 2202 import_from: str 2203 """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
2263class WeightsEntryDescrBase(FileDescr): 2264 type: ClassVar[WeightsFormat] 2265 weights_format_name: ClassVar[str] # human readable 2266 2267 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2268 """Source of the weights file.""" 2269 2270 authors: Optional[List[Author]] = None 2271 """Authors 2272 Either the person(s) that have trained this model resulting in the original weights file. 2273 (If this is the initial weights entry, i.e. it does not have a `parent`) 2274 Or the person(s) who have converted the weights to this weights format. 2275 (If this is a child weight, i.e. it has a `parent` field) 2276 """ 2277 2278 parent: Annotated[ 2279 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"]) 2280 ] = None 2281 """The source weights these weights were converted from. 2282 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`, 2283 The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights. 2284 All weight entries except one (the initial set of weights resulting from training the model), 2285 need to have this field.""" 2286 2287 comment: str = "" 2288 """A comment about this weights entry, for example how these weights were created.""" 2289 2290 @model_validator(mode="after") 2291 def _validate(self) -> Self: 2292 if self.type == self.parent: 2293 raise ValueError("Weights entry can't be it's own parent.") 2294 2295 return self 2296 2297 @model_serializer(mode="wrap", when_used="unless-none") 2298 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 2299 return package_file_descr_serializer(self, nxt, info)
A file description
Source of the weights file.
The source weights these weights were converted from.
For example, if a model's weights were converted from the pytorch_state_dict format to torchscript, the pytorch_state_dict weights entry has no parent and is the parent of the torchscript weights.
All weight entries except one (the initial set of weights resulting from training the model) need to have this field.
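As a hedged illustration of this parent relationship (file names and versions are made up and other required fields are omitted), the original pytorch_state_dict weights and a torchscript entry converted from them could be sketched as plain dictionaries:

weights = {
    "pytorch_state_dict": {
        "source": "weights.pt",  # hypothetical file name; the original weights carry no `parent`
        "pytorch_version": "2.1",
    },
    "torchscript": {
        "source": "weights_torchscript.pt",  # hypothetical file name
        "pytorch_version": "2.1",
        "parent": "pytorch_state_dict",  # converted from the original weights
    },
}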
2302class KerasHdf5WeightsDescr(WeightsEntryDescrBase): 2303 type = "keras_hdf5" 2304 weights_format_name: ClassVar[str] = "Keras HDF5" 2305 tensorflow_version: Version 2306 """TensorFlow version used to create these weights."""
A file description
TensorFlow version used to create these weights.
2309class OnnxWeightsDescr(WeightsEntryDescrBase): 2310 type = "onnx" 2311 weights_format_name: ClassVar[str] = "ONNX" 2312 opset_version: Annotated[int, Ge(7)] 2313 """ONNX opset version"""
A file description
2316class PytorchStateDictWeightsDescr(WeightsEntryDescrBase): 2317 type = "pytorch_state_dict" 2318 weights_format_name: ClassVar[str] = "Pytorch State Dict" 2319 architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr] 2320 pytorch_version: Version 2321 """Version of the PyTorch library used. 2322 If `architecture.depencencies` is specified it has to include pytorch and any version pinning has to be compatible. 2323 """ 2324 dependencies: Optional[FileDescr_dependencies] = None 2325 """Custom depencies beyond pytorch described in a Conda environment file. 2326 Allows to specify custom dependencies, see conda docs: 2327 - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms) 2328 - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually) 2329 2330 The conda environment file should include pytorch and any version pinning has to be compatible with 2331 **pytorch_version**. 2332 """
A file description
Version of the PyTorch library used.
If architecture.dependencies is specified it has to include pytorch, and any version pinning has to be compatible.
Custom dependencies beyond pytorch, described in a conda environment file. This allows specifying custom dependencies; see the conda docs:
- Exporting an environment file across platforms: https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms
- Creating an environment file manually: https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually
The conda environment file should include pytorch and any version pinning has to be compatible with pytorch_version.
2335class TensorflowJsWeightsDescr(WeightsEntryDescrBase): 2336 type = "tensorflow_js" 2337 weights_format_name: ClassVar[str] = "Tensorflow.js" 2338 tensorflow_version: Version 2339 """Version of the TensorFlow library used.""" 2340 2341 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2342 """The multi-file weights. 2343 All required files/folders should be a zip archive."""
A file description
Version of the TensorFlow library used.
The multi-file weights. All required files/folders should be provided as a zip archive.
2346class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase): 2347 type = "tensorflow_saved_model_bundle" 2348 weights_format_name: ClassVar[str] = "Tensorflow Saved Model" 2349 tensorflow_version: Version 2350 """Version of the TensorFlow library used.""" 2351 2352 dependencies: Optional[FileDescr_dependencies] = None 2353 """Custom dependencies beyond tensorflow. 2354 Should include tensorflow and any version pinning has to be compatible with **tensorflow_version**.""" 2355 2356 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2357 """The multi-file weights. 2358 All required files/folders should be a zip archive."""
A file description
Version of the TensorFlow library used.
Custom dependencies beyond tensorflow. Should include tensorflow, and any version pinning has to be compatible with tensorflow_version.
The multi-file weights. All required files/folders should be provided as a zip archive.
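Bundling a multi-file weights folder into the expected zip archive can be done with the standard library; a hedged sketch with hypothetical paths:

import zipfile
from pathlib import Path

def zip_folder(folder: Path, zip_path: Path) -> Path:
    # Write every file below `folder` into a zip archive, keeping relative paths.
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for file in sorted(folder.rglob("*")):
            if file.is_file():
                zf.write(file, arcname=file.relative_to(folder))
    return zip_path

# zip_folder(Path("saved_model_bundle"), Path("tf_weights.zip"))  # hypothetical paths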
2361class TorchscriptWeightsDescr(WeightsEntryDescrBase): 2362 type = "torchscript" 2363 weights_format_name: ClassVar[str] = "TorchScript" 2364 pytorch_version: Version 2365 """Version of the PyTorch library used."""
A file description
Version of the PyTorch library used.
2368class WeightsDescr(Node): 2369 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None 2370 onnx: Optional[OnnxWeightsDescr] = None 2371 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None 2372 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None 2373 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = ( 2374 None 2375 ) 2376 torchscript: Optional[TorchscriptWeightsDescr] = None 2377 2378 @model_validator(mode="after") 2379 def check_entries(self) -> Self: 2380 entries = {wtype for wtype, entry in self if entry is not None} 2381 2382 if not entries: 2383 raise ValueError("Missing weights entry") 2384 2385 entries_wo_parent = { 2386 wtype 2387 for wtype, entry in self 2388 if entry is not None and hasattr(entry, "parent") and entry.parent is None 2389 } 2390 if len(entries_wo_parent) != 1: 2391 issue_warning( 2392 "Exactly one weights entry may not specify the `parent` field (got" 2393 + " {value}). That entry is considered the original set of model weights." 2394 + " Other weight formats are created through conversion of the orignal or" 2395 + " already converted weights. They have to reference the weights format" 2396 + " they were converted from as their `parent`.", 2397 value=len(entries_wo_parent), 2398 field="weights", 2399 ) 2400 2401 for wtype, entry in self: 2402 if entry is None: 2403 continue 2404 2405 assert hasattr(entry, "type") 2406 assert hasattr(entry, "parent") 2407 assert wtype == entry.type 2408 if ( 2409 entry.parent is not None and entry.parent not in entries 2410 ): # self reference checked for `parent` field 2411 raise ValueError( 2412 f"`weights.{wtype}.parent={entry.parent} not in specified weight" 2413 + f" formats: {entries}" 2414 ) 2415 2416 return self 2417 2418 def __getitem__( 2419 self, 2420 key: Literal[ 2421 "keras_hdf5", 2422 "onnx", 2423 "pytorch_state_dict", 2424 "tensorflow_js", 2425 "tensorflow_saved_model_bundle", 2426 "torchscript", 2427 ], 2428 ): 2429 if key == "keras_hdf5": 2430 ret = self.keras_hdf5 2431 elif key == "onnx": 2432 ret = self.onnx 2433 elif key == "pytorch_state_dict": 2434 ret = self.pytorch_state_dict 2435 elif key == "tensorflow_js": 2436 ret = self.tensorflow_js 2437 elif key == "tensorflow_saved_model_bundle": 2438 ret = self.tensorflow_saved_model_bundle 2439 elif key == "torchscript": 2440 ret = self.torchscript 2441 else: 2442 raise KeyError(key) 2443 2444 if ret is None: 2445 raise KeyError(key) 2446 2447 return ret 2448 2449 @property 2450 def available_formats(self): 2451 return { 2452 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 2453 **({} if self.onnx is None else {"onnx": self.onnx}), 2454 **( 2455 {} 2456 if self.pytorch_state_dict is None 2457 else {"pytorch_state_dict": self.pytorch_state_dict} 2458 ), 2459 **( 2460 {} 2461 if self.tensorflow_js is None 2462 else {"tensorflow_js": self.tensorflow_js} 2463 ), 2464 **( 2465 {} 2466 if self.tensorflow_saved_model_bundle is None 2467 else { 2468 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 2469 } 2470 ), 2471 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 2472 } 2473 2474 @property 2475 def missing_formats(self): 2476 return { 2477 wf for wf in get_args(WeightsFormat) if wf not in self.available_formats 2478 }
2378 @model_validator(mode="after") 2379 def check_entries(self) -> Self: 2380 entries = {wtype for wtype, entry in self if entry is not None} 2381 2382 if not entries: 2383 raise ValueError("Missing weights entry") 2384 2385 entries_wo_parent = { 2386 wtype 2387 for wtype, entry in self 2388 if entry is not None and hasattr(entry, "parent") and entry.parent is None 2389 } 2390 if len(entries_wo_parent) != 1: 2391 issue_warning( 2392 "Exactly one weights entry may not specify the `parent` field (got" 2393 + " {value}). That entry is considered the original set of model weights." 2394 + " Other weight formats are created through conversion of the orignal or" 2395 + " already converted weights. They have to reference the weights format" 2396 + " they were converted from as their `parent`.", 2397 value=len(entries_wo_parent), 2398 field="weights", 2399 ) 2400 2401 for wtype, entry in self: 2402 if entry is None: 2403 continue 2404 2405 assert hasattr(entry, "type") 2406 assert hasattr(entry, "parent") 2407 assert wtype == entry.type 2408 if ( 2409 entry.parent is not None and entry.parent not in entries 2410 ): # self reference checked for `parent` field 2411 raise ValueError( 2412 f"`weights.{wtype}.parent={entry.parent} not in specified weight" 2413 + f" formats: {entries}" 2414 ) 2415 2416 return self
2449 @property 2450 def available_formats(self): 2451 return { 2452 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 2453 **({} if self.onnx is None else {"onnx": self.onnx}), 2454 **( 2455 {} 2456 if self.pytorch_state_dict is None 2457 else {"pytorch_state_dict": self.pytorch_state_dict} 2458 ), 2459 **( 2460 {} 2461 if self.tensorflow_js is None 2462 else {"tensorflow_js": self.tensorflow_js} 2463 ), 2464 **( 2465 {} 2466 if self.tensorflow_saved_model_bundle is None 2467 else { 2468 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 2469 } 2470 ), 2471 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 2472 }
2485class LinkedModel(LinkedResourceBase): 2486 """Reference to a bioimage.io model.""" 2487 2488 id: ModelId 2489 """A valid model `id` from the bioimage.io collection."""
Reference to a bioimage.io model.
2511class ReproducibilityTolerance(Node, extra="allow"): 2512 """Describes what small numerical differences -- if any -- may be tolerated 2513 in the generated output when executing in different environments. 2514 2515 A tensor element *output* is considered mismatched to the **test_tensor** if 2516 abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**). 2517 (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).) 2518 2519 Motivation: 2520 For testing we can request the respective deep learning frameworks to be as 2521 reproducible as possible by setting seeds and chosing deterministic algorithms, 2522 but differences in operating systems, available hardware and installed drivers 2523 may still lead to numerical differences. 2524 """ 2525 2526 relative_tolerance: RelativeTolerance = 1e-3 2527 """Maximum relative tolerance of reproduced test tensor.""" 2528 2529 absolute_tolerance: AbsoluteTolerance = 1e-4 2530 """Maximum absolute tolerance of reproduced test tensor.""" 2531 2532 mismatched_elements_per_million: MismatchedElementsPerMillion = 100 2533 """Maximum number of mismatched elements/pixels per million to tolerate.""" 2534 2535 output_ids: Sequence[TensorId] = () 2536 """Limits the output tensor IDs these reproducibility details apply to.""" 2537 2538 weights_formats: Sequence[WeightsFormat] = () 2539 """Limits the weights formats these details apply to."""
Describes what small numerical differences -- if any -- may be tolerated in the generated output when executing in different environments.
A tensor element output is considered mismatched to the test_tensor if abs(output - test_tensor) > absolute_tolerance + relative_tolerance * abs(test_tensor). (Internally we call numpy.testing.assert_allclose.)
Motivation:
For testing we can request the respective deep learning frameworks to be as reproducible as possible by setting seeds and choosing deterministic algorithms, but differences in operating systems, available hardware and installed drivers may still lead to numerical differences.
Maximum relative tolerance of reproduced test tensor.
Maximum absolute tolerance of reproduced test tensor.
Maximum number of mismatched elements/pixels per million to tolerate.
Limits the output tensor IDs these reproducibility details apply to.
Limits the weights formats these details apply to.
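The mismatch criterion quoted above can be evaluated with a few lines of numpy; the following hedged sketch counts mismatched elements per million for given tolerances (it is not the library's internal test routine):

import numpy as np

def mismatched_per_million(output, test_tensor, absolute_tolerance=1e-4, relative_tolerance=1e-3):
    # Count elements where |output - test| exceeds atol + rtol * |test|, per million elements.
    mismatch = np.abs(output - test_tensor) > (
        absolute_tolerance + relative_tolerance * np.abs(test_tensor)
    )
    return mismatch.sum() / mismatch.size * 1e6

rng = np.random.default_rng(0)
ref = rng.normal(size=(1, 3, 64, 64)).astype("float32")
noisy = ref + rng.normal(scale=1e-5, size=ref.shape).astype("float32")
print(mismatched_per_million(noisy, ref))  # well below the default of 100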
2542class BioimageioConfig(Node, extra="allow"): 2543 reproducibility_tolerance: Sequence[ReproducibilityTolerance] = () 2544 """Tolerances to allow when reproducing the model's test outputs 2545 from the model's test inputs. 2546 Only the first entry matching tensor id and weights format is considered. 2547 """
Tolerances to allow when reproducing the model's test outputs from the model's test inputs. Only the first entry matching tensor id and weights format is considered.
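A hedged sketch of the "first matching entry" selection described above, using plain dictionaries in place of ReproducibilityTolerance nodes:

def first_matching_tolerance(entries, output_id, weights_format):
    # Empty output_ids/weights_formats mean "applies to all".
    for entry in entries:
        ids = entry.get("output_ids", ())
        formats = entry.get("weights_formats", ())
        if (not ids or output_id in ids) and (not formats or weights_format in formats):
            return entry
    return None

entries = [
    {"output_ids": ("mask",), "absolute_tolerance": 1e-3},
    {"absolute_tolerance": 1e-4},  # unrestricted fallback
]
print(first_matching_tolerance(entries, "probabilities", "onnx"))  # -> the fallback entry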
2550class Config(Node, extra="allow"): 2551 bioimageio: BioimageioConfig = Field(default_factory=BioimageioConfig)
2554class ModelDescr(GenericModelDescrBase): 2555 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. 2556 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 2557 """ 2558 2559 implemented_format_version: ClassVar[Literal["0.5.4"]] = "0.5.4" 2560 if TYPE_CHECKING: 2561 format_version: Literal["0.5.4"] = "0.5.4" 2562 else: 2563 format_version: Literal["0.5.4"] 2564 """Version of the bioimage.io model description specification used. 2565 When creating a new model always use the latest micro/patch version described here. 2566 The `format_version` is important for any consumer software to understand how to parse the fields. 2567 """ 2568 2569 implemented_type: ClassVar[Literal["model"]] = "model" 2570 if TYPE_CHECKING: 2571 type: Literal["model"] = "model" 2572 else: 2573 type: Literal["model"] 2574 """Specialized resource type 'model'""" 2575 2576 id: Optional[ModelId] = None 2577 """bioimage.io-wide unique resource identifier 2578 assigned by bioimage.io; version **un**specific.""" 2579 2580 authors: NotEmpty[List[Author]] 2581 """The authors are the creators of the model RDF and the primary points of contact.""" 2582 2583 documentation: FileSource_documentation 2584 """URL or relative path to a markdown file with additional documentation. 2585 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. 2586 The documentation should include a '#[#] Validation' (sub)section 2587 with details on how to quantitatively validate the model on unseen data.""" 2588 2589 @field_validator("documentation", mode="after") 2590 @classmethod 2591 def _validate_documentation( 2592 cls, value: FileSource_documentation 2593 ) -> FileSource_documentation: 2594 if not get_validation_context().perform_io_checks: 2595 return value 2596 2597 doc_reader = get_reader(value) 2598 doc_content = doc_reader.read().decode(encoding="utf-8") 2599 if not re.search("#.*[vV]alidation", doc_content): 2600 issue_warning( 2601 "No '# Validation' (sub)section found in {value}.", 2602 value=value, 2603 field="documentation", 2604 ) 2605 2606 return value 2607 2608 inputs: NotEmpty[Sequence[InputTensorDescr]] 2609 """Describes the input tensors expected by this model.""" 2610 2611 @field_validator("inputs", mode="after") 2612 @classmethod 2613 def _validate_input_axes( 2614 cls, inputs: Sequence[InputTensorDescr] 2615 ) -> Sequence[InputTensorDescr]: 2616 input_size_refs = cls._get_axes_with_independent_size(inputs) 2617 2618 for i, ipt in enumerate(inputs): 2619 valid_independent_refs: Dict[ 2620 Tuple[TensorId, AxisId], 2621 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2622 ] = { 2623 **{ 2624 (ipt.id, a.id): (ipt, a, a.size) 2625 for a in ipt.axes 2626 if not isinstance(a, BatchAxis) 2627 and isinstance(a.size, (int, ParameterizedSize)) 2628 }, 2629 **input_size_refs, 2630 } 2631 for a, ax in enumerate(ipt.axes): 2632 cls._validate_axis( 2633 "inputs", 2634 i=i, 2635 tensor_id=ipt.id, 2636 a=a, 2637 axis=ax, 2638 valid_independent_refs=valid_independent_refs, 2639 ) 2640 return inputs 2641 2642 @staticmethod 2643 def _validate_axis( 2644 field_name: str, 2645 i: int, 2646 tensor_id: TensorId, 2647 a: int, 2648 axis: AnyAxis, 2649 valid_independent_refs: Dict[ 2650 Tuple[TensorId, AxisId], 2651 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2652 ], 2653 ): 2654 if isinstance(axis, BatchAxis) or isinstance( 2655 axis.size, (int, ParameterizedSize, 
DataDependentSize) 2656 ): 2657 return 2658 elif not isinstance(axis.size, SizeReference): 2659 assert_never(axis.size) 2660 2661 # validate axis.size SizeReference 2662 ref = (axis.size.tensor_id, axis.size.axis_id) 2663 if ref not in valid_independent_refs: 2664 raise ValueError( 2665 "Invalid tensor axis reference at" 2666 + f" {field_name}[{i}].axes[{a}].size: {axis.size}." 2667 ) 2668 if ref == (tensor_id, axis.id): 2669 raise ValueError( 2670 "Self-referencing not allowed for" 2671 + f" {field_name}[{i}].axes[{a}].size: {axis.size}" 2672 ) 2673 if axis.type == "channel": 2674 if valid_independent_refs[ref][1].type != "channel": 2675 raise ValueError( 2676 "A channel axis' size may only reference another fixed size" 2677 + " channel axis." 2678 ) 2679 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names: 2680 ref_size = valid_independent_refs[ref][2] 2681 assert isinstance(ref_size, int), ( 2682 "channel axis ref (another channel axis) has to specify fixed" 2683 + " size" 2684 ) 2685 generated_channel_names = [ 2686 Identifier(axis.channel_names.format(i=i)) 2687 for i in range(1, ref_size + 1) 2688 ] 2689 axis.channel_names = generated_channel_names 2690 2691 if (ax_unit := getattr(axis, "unit", None)) != ( 2692 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None) 2693 ): 2694 raise ValueError( 2695 "The units of an axis and its reference axis need to match, but" 2696 + f" '{ax_unit}' != '{ref_unit}'." 2697 ) 2698 ref_axis = valid_independent_refs[ref][1] 2699 if isinstance(ref_axis, BatchAxis): 2700 raise ValueError( 2701 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}" 2702 + " (a batch axis is not allowed as reference)." 2703 ) 2704 2705 if isinstance(axis, WithHalo): 2706 min_size = axis.size.get_size(axis, ref_axis, n=0) 2707 if (min_size - 2 * axis.halo) < 1: 2708 raise ValueError( 2709 f"axis {axis.id} with minimum size {min_size} is too small for halo" 2710 + f" {axis.halo}." 2711 ) 2712 2713 input_halo = axis.halo * axis.scale / ref_axis.scale 2714 if input_halo != int(input_halo) or input_halo % 2 == 1: 2715 raise ValueError( 2716 f"input_halo {input_halo} (output_halo {axis.halo} *" 2717 + f" output_scale {axis.scale} / input_scale {ref_axis.scale})" 2718 + f" {tensor_id}.{axis.id}." 
2719 ) 2720 2721 @model_validator(mode="after") 2722 def _validate_test_tensors(self) -> Self: 2723 if not get_validation_context().perform_io_checks: 2724 return self 2725 2726 test_output_arrays = [load_array(descr.test_tensor) for descr in self.outputs] 2727 test_input_arrays = [load_array(descr.test_tensor) for descr in self.inputs] 2728 2729 tensors = { 2730 descr.id: (descr, array) 2731 for descr, array in zip( 2732 chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays 2733 ) 2734 } 2735 validate_tensors(tensors, tensor_origin="test_tensor") 2736 2737 output_arrays = { 2738 descr.id: array for descr, array in zip(self.outputs, test_output_arrays) 2739 } 2740 for rep_tol in self.config.bioimageio.reproducibility_tolerance: 2741 if not rep_tol.absolute_tolerance: 2742 continue 2743 2744 if rep_tol.output_ids: 2745 out_arrays = { 2746 oid: a 2747 for oid, a in output_arrays.items() 2748 if oid in rep_tol.output_ids 2749 } 2750 else: 2751 out_arrays = output_arrays 2752 2753 for out_id, array in out_arrays.items(): 2754 if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01: 2755 raise ValueError( 2756 "config.bioimageio.reproducibility_tolerance.absolute_tolerance=" 2757 + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}" 2758 + f" (1% of the maximum value of the test tensor '{out_id}')" 2759 ) 2760 2761 return self 2762 2763 @model_validator(mode="after") 2764 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self: 2765 ipt_refs = {t.id for t in self.inputs} 2766 out_refs = {t.id for t in self.outputs} 2767 for ipt in self.inputs: 2768 for p in ipt.preprocessing: 2769 ref = p.kwargs.get("reference_tensor") 2770 if ref is None: 2771 continue 2772 if ref not in ipt_refs: 2773 raise ValueError( 2774 f"`reference_tensor` '{ref}' not found. Valid input tensor" 2775 + f" references are: {ipt_refs}." 2776 ) 2777 2778 for out in self.outputs: 2779 for p in out.postprocessing: 2780 ref = p.kwargs.get("reference_tensor") 2781 if ref is None: 2782 continue 2783 2784 if ref not in ipt_refs and ref not in out_refs: 2785 raise ValueError( 2786 f"`reference_tensor` '{ref}' not found. Valid tensor references" 2787 + f" are: {ipt_refs | out_refs}." 2788 ) 2789 2790 return self 2791 2792 # TODO: use validate funcs in validate_test_tensors 2793 # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]: 2794 2795 name: Annotated[ 2796 Annotated[ 2797 str, RestrictCharacters(string.ascii_letters + string.digits + "_+- ()") 2798 ], 2799 MinLen(5), 2800 MaxLen(128), 2801 warn(MaxLen(64), "Name longer than 64 characters.", INFO), 2802 ] 2803 """A human-readable name of this model. 2804 It should be no longer than 64 characters 2805 and may only contain letter, number, underscore, minus, parentheses and spaces. 2806 We recommend to chose a name that refers to the model's task and image modality. 
2807 """ 2808 2809 outputs: NotEmpty[Sequence[OutputTensorDescr]] 2810 """Describes the output tensors.""" 2811 2812 @field_validator("outputs", mode="after") 2813 @classmethod 2814 def _validate_tensor_ids( 2815 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo 2816 ) -> Sequence[OutputTensorDescr]: 2817 tensor_ids = [ 2818 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", []) 2819 ] 2820 duplicate_tensor_ids: List[str] = [] 2821 seen: Set[str] = set() 2822 for t in tensor_ids: 2823 if t in seen: 2824 duplicate_tensor_ids.append(t) 2825 2826 seen.add(t) 2827 2828 if duplicate_tensor_ids: 2829 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}") 2830 2831 return outputs 2832 2833 @staticmethod 2834 def _get_axes_with_parameterized_size( 2835 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2836 ): 2837 return { 2838 f"{t.id}.{a.id}": (t, a, a.size) 2839 for t in io 2840 for a in t.axes 2841 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize) 2842 } 2843 2844 @staticmethod 2845 def _get_axes_with_independent_size( 2846 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2847 ): 2848 return { 2849 (t.id, a.id): (t, a, a.size) 2850 for t in io 2851 for a in t.axes 2852 if not isinstance(a, BatchAxis) 2853 and isinstance(a.size, (int, ParameterizedSize)) 2854 } 2855 2856 @field_validator("outputs", mode="after") 2857 @classmethod 2858 def _validate_output_axes( 2859 cls, outputs: List[OutputTensorDescr], info: ValidationInfo 2860 ) -> List[OutputTensorDescr]: 2861 input_size_refs = cls._get_axes_with_independent_size( 2862 info.data.get("inputs", []) 2863 ) 2864 output_size_refs = cls._get_axes_with_independent_size(outputs) 2865 2866 for i, out in enumerate(outputs): 2867 valid_independent_refs: Dict[ 2868 Tuple[TensorId, AxisId], 2869 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2870 ] = { 2871 **{ 2872 (out.id, a.id): (out, a, a.size) 2873 for a in out.axes 2874 if not isinstance(a, BatchAxis) 2875 and isinstance(a.size, (int, ParameterizedSize)) 2876 }, 2877 **input_size_refs, 2878 **output_size_refs, 2879 } 2880 for a, ax in enumerate(out.axes): 2881 cls._validate_axis( 2882 "outputs", 2883 i, 2884 out.id, 2885 a, 2886 ax, 2887 valid_independent_refs=valid_independent_refs, 2888 ) 2889 2890 return outputs 2891 2892 packaged_by: List[Author] = Field( 2893 default_factory=cast(Callable[[], List[Author]], list) 2894 ) 2895 """The persons that have packaged and uploaded this model. 2896 Only required if those persons differ from the `authors`.""" 2897 2898 parent: Optional[LinkedModel] = None 2899 """The model from which this model is derived, e.g. by fine-tuning the weights.""" 2900 2901 @model_validator(mode="after") 2902 def _validate_parent_is_not_self(self) -> Self: 2903 if self.parent is not None and self.parent.id == self.id: 2904 raise ValueError("A model description may not reference itself as parent.") 2905 2906 return self 2907 2908 run_mode: Annotated[ 2909 Optional[RunMode], 2910 warn(None, "Run mode '{value}' has limited support across consumer softwares."), 2911 ] = None 2912 """Custom run mode for this model: for more complex prediction procedures like test time 2913 data augmentation that currently cannot be expressed in the specification. 
2914 No standard run modes are defined yet.""" 2915 2916 timestamp: Datetime = Field(default_factory=Datetime.now) 2917 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format 2918 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat). 2919 (In Python a datetime object is valid, too).""" 2920 2921 training_data: Annotated[ 2922 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02], 2923 Field(union_mode="left_to_right"), 2924 ] = None 2925 """The dataset used to train this model""" 2926 2927 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)] 2928 """The weights for this model. 2929 Weights can be given for different formats, but should otherwise be equivalent. 2930 The available weight formats determine which consumers can use this model.""" 2931 2932 config: Config = Field(default_factory=Config) 2933 2934 @model_validator(mode="after") 2935 def _add_default_cover(self) -> Self: 2936 if not get_validation_context().perform_io_checks or self.covers: 2937 return self 2938 2939 try: 2940 generated_covers = generate_covers( 2941 [(t, load_array(t.test_tensor)) for t in self.inputs], 2942 [(t, load_array(t.test_tensor)) for t in self.outputs], 2943 ) 2944 except Exception as e: 2945 issue_warning( 2946 "Failed to generate cover image(s): {e}", 2947 value=self.covers, 2948 msg_context=dict(e=e), 2949 field="covers", 2950 ) 2951 else: 2952 self.covers.extend(generated_covers) 2953 2954 return self 2955 2956 def get_input_test_arrays(self) -> List[NDArray[Any]]: 2957 data = [load_array(ipt.test_tensor) for ipt in self.inputs] 2958 assert all(isinstance(d, np.ndarray) for d in data) 2959 return data 2960 2961 def get_output_test_arrays(self) -> List[NDArray[Any]]: 2962 data = [load_array(out.test_tensor) for out in self.outputs] 2963 assert all(isinstance(d, np.ndarray) for d in data) 2964 return data 2965 2966 @staticmethod 2967 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 2968 batch_size = 1 2969 tensor_with_batchsize: Optional[TensorId] = None 2970 for tid in tensor_sizes: 2971 for aid, s in tensor_sizes[tid].items(): 2972 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 2973 continue 2974 2975 if batch_size != 1: 2976 assert tensor_with_batchsize is not None 2977 raise ValueError( 2978 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 2979 ) 2980 2981 batch_size = s 2982 tensor_with_batchsize = tid 2983 2984 return batch_size 2985 2986 def get_output_tensor_sizes( 2987 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 2988 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: 2989 """Returns the tensor output sizes for given **input_sizes**. 2990 Only if **input_sizes** has a valid input shape, the tensor output size is exact. 
2991 Otherwise it might be larger than the actual (valid) output""" 2992 batch_size = self.get_batch_size(input_sizes) 2993 ns = self.get_ns(input_sizes) 2994 2995 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 2996 return tensor_sizes.outputs 2997 2998 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 2999 """get parameter `n` for each parameterized axis 3000 such that the valid input size is >= the given input size""" 3001 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 3002 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 3003 for tid in input_sizes: 3004 for aid, s in input_sizes[tid].items(): 3005 size_descr = axes[tid][aid].size 3006 if isinstance(size_descr, ParameterizedSize): 3007 ret[(tid, aid)] = size_descr.get_n(s) 3008 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 3009 pass 3010 else: 3011 assert_never(size_descr) 3012 3013 return ret 3014 3015 def get_tensor_sizes( 3016 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 3017 ) -> _TensorSizes: 3018 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 3019 return _TensorSizes( 3020 { 3021 t: { 3022 aa: axis_sizes.inputs[(tt, aa)] 3023 for tt, aa in axis_sizes.inputs 3024 if tt == t 3025 } 3026 for t in {tt for tt, _ in axis_sizes.inputs} 3027 }, 3028 { 3029 t: { 3030 aa: axis_sizes.outputs[(tt, aa)] 3031 for tt, aa in axis_sizes.outputs 3032 if tt == t 3033 } 3034 for t in {tt for tt, _ in axis_sizes.outputs} 3035 }, 3036 ) 3037 3038 def get_axis_sizes( 3039 self, 3040 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], 3041 batch_size: Optional[int] = None, 3042 *, 3043 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None, 3044 ) -> _AxisSizes: 3045 """Determine input and output block shape for scale factors **ns** 3046 of parameterized input sizes. 3047 3048 Args: 3049 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) 3050 that is parameterized as `size = min + n * step`. 3051 batch_size: The desired size of the batch dimension. 3052 If given **batch_size** overwrites any batch size present in 3053 **max_input_shape**. Default 1. 3054 max_input_shape: Limits the derived block shapes. 3055 Each axis for which the input size, parameterized by `n`, is larger 3056 than **max_input_shape** is set to the minimal value `n_min` for which 3057 this is still true. 3058 Use this for small input samples or large values of **ns**. 3059 Or simply whenever you know the full input shape. 3060 3061 Returns: 3062 Resolved axis sizes for model inputs and outputs. 
3063 """ 3064 max_input_shape = max_input_shape or {} 3065 if batch_size is None: 3066 for (_t_id, a_id), s in max_input_shape.items(): 3067 if a_id == BATCH_AXIS_ID: 3068 batch_size = s 3069 break 3070 else: 3071 batch_size = 1 3072 3073 all_axes = { 3074 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs) 3075 } 3076 3077 inputs: Dict[Tuple[TensorId, AxisId], int] = {} 3078 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {} 3079 3080 def get_axis_size(a: Union[InputAxis, OutputAxis]): 3081 if isinstance(a, BatchAxis): 3082 if (t_descr.id, a.id) in ns: 3083 logger.warning( 3084 "Ignoring unexpected size increment factor (n) for batch axis" 3085 + " of tensor '{}'.", 3086 t_descr.id, 3087 ) 3088 return batch_size 3089 elif isinstance(a.size, int): 3090 if (t_descr.id, a.id) in ns: 3091 logger.warning( 3092 "Ignoring unexpected size increment factor (n) for fixed size" 3093 + " axis '{}' of tensor '{}'.", 3094 a.id, 3095 t_descr.id, 3096 ) 3097 return a.size 3098 elif isinstance(a.size, ParameterizedSize): 3099 if (t_descr.id, a.id) not in ns: 3100 raise ValueError( 3101 "Size increment factor (n) missing for parametrized axis" 3102 + f" '{a.id}' of tensor '{t_descr.id}'." 3103 ) 3104 n = ns[(t_descr.id, a.id)] 3105 s_max = max_input_shape.get((t_descr.id, a.id)) 3106 if s_max is not None: 3107 n = min(n, a.size.get_n(s_max)) 3108 3109 return a.size.get_size(n) 3110 3111 elif isinstance(a.size, SizeReference): 3112 if (t_descr.id, a.id) in ns: 3113 logger.warning( 3114 "Ignoring unexpected size increment factor (n) for axis '{}'" 3115 + " of tensor '{}' with size reference.", 3116 a.id, 3117 t_descr.id, 3118 ) 3119 assert not isinstance(a, BatchAxis) 3120 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id] 3121 assert not isinstance(ref_axis, BatchAxis) 3122 ref_key = (a.size.tensor_id, a.size.axis_id) 3123 ref_size = inputs.get(ref_key, outputs.get(ref_key)) 3124 assert ref_size is not None, ref_key 3125 assert not isinstance(ref_size, _DataDepSize), ref_key 3126 return a.size.get_size( 3127 axis=a, 3128 ref_axis=ref_axis, 3129 ref_size=ref_size, 3130 ) 3131 elif isinstance(a.size, DataDependentSize): 3132 if (t_descr.id, a.id) in ns: 3133 logger.warning( 3134 "Ignoring unexpected increment factor (n) for data dependent" 3135 + " size axis '{}' of tensor '{}'.", 3136 a.id, 3137 t_descr.id, 3138 ) 3139 return _DataDepSize(a.size.min, a.size.max) 3140 else: 3141 assert_never(a.size) 3142 3143 # first resolve all , but the `SizeReference` input sizes 3144 for t_descr in self.inputs: 3145 for a in t_descr.axes: 3146 if not isinstance(a.size, SizeReference): 3147 s = get_axis_size(a) 3148 assert not isinstance(s, _DataDepSize) 3149 inputs[t_descr.id, a.id] = s 3150 3151 # resolve all other input axis sizes 3152 for t_descr in self.inputs: 3153 for a in t_descr.axes: 3154 if isinstance(a.size, SizeReference): 3155 s = get_axis_size(a) 3156 assert not isinstance(s, _DataDepSize) 3157 inputs[t_descr.id, a.id] = s 3158 3159 # resolve all output axis sizes 3160 for t_descr in self.outputs: 3161 for a in t_descr.axes: 3162 assert not isinstance(a.size, ParameterizedSize) 3163 s = get_axis_size(a) 3164 outputs[t_descr.id, a.id] = s 3165 3166 return _AxisSizes(inputs=inputs, outputs=outputs) 3167 3168 @model_validator(mode="before") 3169 @classmethod 3170 def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]: 3171 cls.convert_from_old_format_wo_validation(data) 3172 return data 3173 3174 @classmethod 3175 def convert_from_old_format_wo_validation(cls, 
data: Dict[str, Any]) -> None: 3176 """Convert metadata following an older format version to this classes' format 3177 without validating the result. 3178 """ 3179 if ( 3180 data.get("type") == "model" 3181 and isinstance(fv := data.get("format_version"), str) 3182 and fv.count(".") == 2 3183 ): 3184 fv_parts = fv.split(".") 3185 if any(not p.isdigit() for p in fv_parts): 3186 return 3187 3188 fv_tuple = tuple(map(int, fv_parts)) 3189 3190 assert cls.implemented_format_version_tuple[0:2] == (0, 5) 3191 if fv_tuple[:2] in ((0, 3), (0, 4)): 3192 m04 = _ModelDescr_v0_4.load(data) 3193 if isinstance(m04, InvalidDescr): 3194 try: 3195 updated = _model_conv.convert_as_dict( 3196 m04 # pyright: ignore[reportArgumentType] 3197 ) 3198 except Exception as e: 3199 logger.error( 3200 "Failed to convert from invalid model 0.4 description." 3201 + f"\nerror: {e}" 3202 + "\nProceeding with model 0.5 validation without conversion." 3203 ) 3204 updated = None 3205 else: 3206 updated = _model_conv.convert_as_dict(m04) 3207 3208 if updated is not None: 3209 data.clear() 3210 data.update(updated) 3211 3212 elif fv_tuple[:2] == (0, 5): 3213 # bump patch version 3214 data["format_version"] = cls.implemented_format_version
Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
bioimage.io-wide unique resource identifier assigned by bioimage.io; version unspecific.
URL or relative path to a markdown file with additional documentation. The recommended documentation file name is README.md; an .md suffix is mandatory. The documentation should include a '#[#] Validation' (sub)section with details on how to quantitatively validate the model on unseen data.
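The '#[#] Validation' requirement is checked with a simple regular expression during validation (see _validate_documentation in the source above); a hedged stand-alone sketch of that check:

import re

def has_validation_section(markdown: str) -> bool:
    # True if the documentation contains a '# ... Validation' heading.
    return re.search("#.*[vV]alidation", markdown) is not None

readme = "# My Model\n\n## Validation\nEvaluated on held-out data ...\n"  # hypothetical content
print(has_validation_section(readme))  # True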
Describes the input tensors expected by this model.
A human-readable name of this model. It should be no longer than 64 characters and may only contain letters, numbers, underscores, minus signs, parentheses and spaces. We recommend choosing a name that refers to the model's task and image modality.
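A hedged stand-alone check mirroring these name constraints (allowed characters, 5 to 128 characters, warning above 64); it is not the pydantic validation itself:

import string

_ALLOWED = set(string.ascii_letters + string.digits + "_+- ()")

def check_model_name(name: str) -> None:
    # Raise on names the spec rejects; merely warn above 64 characters.
    if not 5 <= len(name) <= 128:
        raise ValueError("name must be 5 to 128 characters long")
    if set(name) - _ALLOWED:
        raise ValueError("name may only contain letters, digits, '_+- ()' and spaces")
    if len(name) > 64:
        print("warning: name longer than 64 characters")

check_model_name("NucleiSegmentation (2D UNet)")  # hypothetical name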
The persons that have packaged and uploaded this model. Only required if those persons differ from the authors.
The model from which this model is derived, e.g. by fine-tuning the weights.
Custom run mode for this model: for more complex prediction procedures like test time data augmentation that currently cannot be expressed in the specification. No standard run modes are defined yet.
The dataset used to train this model.
The weights for this model. Weights can be given for different formats, but should otherwise be equivalent. The available weight formats determine which consumers can use this model.
2966 @staticmethod 2967 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 2968 batch_size = 1 2969 tensor_with_batchsize: Optional[TensorId] = None 2970 for tid in tensor_sizes: 2971 for aid, s in tensor_sizes[tid].items(): 2972 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 2973 continue 2974 2975 if batch_size != 1: 2976 assert tensor_with_batchsize is not None 2977 raise ValueError( 2978 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 2979 ) 2980 2981 batch_size = s 2982 tensor_with_batchsize = tid 2983 2984 return batch_size
2986 def get_output_tensor_sizes( 2987 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 2988 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: 2989 """Returns the tensor output sizes for given **input_sizes**. 2990 Only if **input_sizes** has a valid input shape, the tensor output size is exact. 2991 Otherwise it might be larger than the actual (valid) output""" 2992 batch_size = self.get_batch_size(input_sizes) 2993 ns = self.get_ns(input_sizes) 2994 2995 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 2996 return tensor_sizes.outputs
Returns the tensor output sizes for the given input_sizes. The output sizes are exact only if input_sizes is a valid input shape; otherwise they might be larger than the actual (valid) output.
2998 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 2999 """get parameter `n` for each parameterized axis 3000 such that the valid input size is >= the given input size""" 3001 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 3002 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 3003 for tid in input_sizes: 3004 for aid, s in input_sizes[tid].items(): 3005 size_descr = axes[tid][aid].size 3006 if isinstance(size_descr, ParameterizedSize): 3007 ret[(tid, aid)] = size_descr.get_n(s) 3008 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 3009 pass 3010 else: 3011 assert_never(size_descr) 3012 3013 return ret
Get parameter n for each parameterized axis such that the valid input size is >= the given input size.
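For an axis parameterized as size = min + n * step, the smallest n whose valid size covers a requested size follows from simple arithmetic; a hedged stand-alone sketch (not the ParameterizedSize implementation itself):

from math import ceil

def smallest_n(s: int, min_size: int, step: int) -> int:
    # Smallest non-negative n with min_size + n * step >= s.
    if step == 0 or s <= min_size:
        return 0
    return ceil((s - min_size) / step)

print(smallest_n(100, 64, 16))  # -> 3, i.e. valid size 64 + 3 * 16 = 112 >= 100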
3015 def get_tensor_sizes( 3016 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 3017 ) -> _TensorSizes: 3018 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 3019 return _TensorSizes( 3020 { 3021 t: { 3022 aa: axis_sizes.inputs[(tt, aa)] 3023 for tt, aa in axis_sizes.inputs 3024 if tt == t 3025 } 3026 for t in {tt for tt, _ in axis_sizes.inputs} 3027 }, 3028 { 3029 t: { 3030 aa: axis_sizes.outputs[(tt, aa)] 3031 for tt, aa in axis_sizes.outputs 3032 if tt == t 3033 } 3034 for t in {tt for tt, _ in axis_sizes.outputs} 3035 }, 3036 )
3038 def get_axis_sizes( 3039 self, 3040 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], 3041 batch_size: Optional[int] = None, 3042 *, 3043 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None, 3044 ) -> _AxisSizes: 3045 """Determine input and output block shape for scale factors **ns** 3046 of parameterized input sizes. 3047 3048 Args: 3049 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) 3050 that is parameterized as `size = min + n * step`. 3051 batch_size: The desired size of the batch dimension. 3052 If given **batch_size** overwrites any batch size present in 3053 **max_input_shape**. Default 1. 3054 max_input_shape: Limits the derived block shapes. 3055 Each axis for which the input size, parameterized by `n`, is larger 3056 than **max_input_shape** is set to the minimal value `n_min` for which 3057 this is still true. 3058 Use this for small input samples or large values of **ns**. 3059 Or simply whenever you know the full input shape. 3060 3061 Returns: 3062 Resolved axis sizes for model inputs and outputs. 3063 """ 3064 max_input_shape = max_input_shape or {} 3065 if batch_size is None: 3066 for (_t_id, a_id), s in max_input_shape.items(): 3067 if a_id == BATCH_AXIS_ID: 3068 batch_size = s 3069 break 3070 else: 3071 batch_size = 1 3072 3073 all_axes = { 3074 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs) 3075 } 3076 3077 inputs: Dict[Tuple[TensorId, AxisId], int] = {} 3078 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {} 3079 3080 def get_axis_size(a: Union[InputAxis, OutputAxis]): 3081 if isinstance(a, BatchAxis): 3082 if (t_descr.id, a.id) in ns: 3083 logger.warning( 3084 "Ignoring unexpected size increment factor (n) for batch axis" 3085 + " of tensor '{}'.", 3086 t_descr.id, 3087 ) 3088 return batch_size 3089 elif isinstance(a.size, int): 3090 if (t_descr.id, a.id) in ns: 3091 logger.warning( 3092 "Ignoring unexpected size increment factor (n) for fixed size" 3093 + " axis '{}' of tensor '{}'.", 3094 a.id, 3095 t_descr.id, 3096 ) 3097 return a.size 3098 elif isinstance(a.size, ParameterizedSize): 3099 if (t_descr.id, a.id) not in ns: 3100 raise ValueError( 3101 "Size increment factor (n) missing for parametrized axis" 3102 + f" '{a.id}' of tensor '{t_descr.id}'." 
3103 ) 3104 n = ns[(t_descr.id, a.id)] 3105 s_max = max_input_shape.get((t_descr.id, a.id)) 3106 if s_max is not None: 3107 n = min(n, a.size.get_n(s_max)) 3108 3109 return a.size.get_size(n) 3110 3111 elif isinstance(a.size, SizeReference): 3112 if (t_descr.id, a.id) in ns: 3113 logger.warning( 3114 "Ignoring unexpected size increment factor (n) for axis '{}'" 3115 + " of tensor '{}' with size reference.", 3116 a.id, 3117 t_descr.id, 3118 ) 3119 assert not isinstance(a, BatchAxis) 3120 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id] 3121 assert not isinstance(ref_axis, BatchAxis) 3122 ref_key = (a.size.tensor_id, a.size.axis_id) 3123 ref_size = inputs.get(ref_key, outputs.get(ref_key)) 3124 assert ref_size is not None, ref_key 3125 assert not isinstance(ref_size, _DataDepSize), ref_key 3126 return a.size.get_size( 3127 axis=a, 3128 ref_axis=ref_axis, 3129 ref_size=ref_size, 3130 ) 3131 elif isinstance(a.size, DataDependentSize): 3132 if (t_descr.id, a.id) in ns: 3133 logger.warning( 3134 "Ignoring unexpected increment factor (n) for data dependent" 3135 + " size axis '{}' of tensor '{}'.", 3136 a.id, 3137 t_descr.id, 3138 ) 3139 return _DataDepSize(a.size.min, a.size.max) 3140 else: 3141 assert_never(a.size) 3142 3143 # first resolve all , but the `SizeReference` input sizes 3144 for t_descr in self.inputs: 3145 for a in t_descr.axes: 3146 if not isinstance(a.size, SizeReference): 3147 s = get_axis_size(a) 3148 assert not isinstance(s, _DataDepSize) 3149 inputs[t_descr.id, a.id] = s 3150 3151 # resolve all other input axis sizes 3152 for t_descr in self.inputs: 3153 for a in t_descr.axes: 3154 if isinstance(a.size, SizeReference): 3155 s = get_axis_size(a) 3156 assert not isinstance(s, _DataDepSize) 3157 inputs[t_descr.id, a.id] = s 3158 3159 # resolve all output axis sizes 3160 for t_descr in self.outputs: 3161 for a in t_descr.axes: 3162 assert not isinstance(a.size, ParameterizedSize) 3163 s = get_axis_size(a) 3164 outputs[t_descr.id, a.id] = s 3165 3166 return _AxisSizes(inputs=inputs, outputs=outputs)
Determine input and output block shape for scale factors ns of parameterized input sizes.
Arguments:
- ns: Scale factor n for each axis (keyed by (tensor_id, axis_id)) that is parameterized as size = min + n * step.
- batch_size: The desired size of the batch dimension. If given, batch_size overwrites any batch size present in max_input_shape. Default 1.
- max_input_shape: Limits the derived block shapes. Each axis for which the input size, parameterized by n, is larger than max_input_shape is set to the minimal value n_min for which this is still true. Use this for small input samples or large values of ns, or simply whenever you know the full input shape.
Returns:
Resolved axis sizes for model inputs and outputs.
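A minimal usage sketch (the RDF path and the choice of n=2 are placeholders, not part of this module): pick one scale factor per parameterized axis and resolve the concrete input and output sizes with get_axis_sizes.

from bioimageio.spec import load_description
from bioimageio.spec.model.v0_5 import ModelDescr, ParameterizedSize

model = load_description("path/to/rdf.yaml")  # placeholder path
assert isinstance(model, ModelDescr)

# choose n=2 for every parameterized axis (size = min + n * step)
ns = {
    (t.id, a.id): 2
    for t in model.inputs
    for a in t.axes
    if isinstance(a.size, ParameterizedSize)
}
sizes = model.get_axis_sizes(ns, batch_size=1)
print(sizes.inputs)   # {(tensor_id, axis_id): int, ...}
print(sizes.outputs)  # ints, or a data-dependent size range for such output axes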
3174 @classmethod 3175 def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None: 3176 """Convert metadata following an older format version to this classes' format 3177 without validating the result. 3178 """ 3179 if ( 3180 data.get("type") == "model" 3181 and isinstance(fv := data.get("format_version"), str) 3182 and fv.count(".") == 2 3183 ): 3184 fv_parts = fv.split(".") 3185 if any(not p.isdigit() for p in fv_parts): 3186 return 3187 3188 fv_tuple = tuple(map(int, fv_parts)) 3189 3190 assert cls.implemented_format_version_tuple[0:2] == (0, 5) 3191 if fv_tuple[:2] in ((0, 3), (0, 4)): 3192 m04 = _ModelDescr_v0_4.load(data) 3193 if isinstance(m04, InvalidDescr): 3194 try: 3195 updated = _model_conv.convert_as_dict( 3196 m04 # pyright: ignore[reportArgumentType] 3197 ) 3198 except Exception as e: 3199 logger.error( 3200 "Failed to convert from invalid model 0.4 description." 3201 + f"\nerror: {e}" 3202 + "\nProceeding with model 0.5 validation without conversion." 3203 ) 3204 updated = None 3205 else: 3206 updated = _model_conv.convert_as_dict(m04) 3207 3208 if updated is not None: 3209 data.clear() 3210 data.update(updated) 3211 3212 elif fv_tuple[:2] == (0, 5): 3213 # bump patch version 3214 data["format_version"] = cls.implemented_format_version
Convert metadata following an older format version to this class's format without validating the result.
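A minimal sketch of the in-place behavior described above, assuming an abbreviated metadata dict (a real RDF contains many more keys): for a 0.5.x description only the patch version is bumped, while 0.3/0.4 descriptions would be converted to the 0.5 layout.

from bioimageio.spec.model.v0_5 import ModelDescr

data = {"type": "model", "format_version": "0.5.0"}  # abbreviated, not a full RDF
ModelDescr.convert_from_old_format_wo_validation(data)
print(data["format_version"])  # bumped to ModelDescr.implemented_format_version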
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
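For orientation, a generic pydantic sketch (not specific to bioimageio.spec; the class, field, and options are illustrative only) of how such a configuration dictionary is declared on a model class.

from pydantic import BaseModel, ConfigDict

class Example(BaseModel):
    # ConfigDict tunes validation/serialization behavior for this model class
    model_config = ConfigDict(extra="forbid", populate_by_name=True)

    some_field: int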
337def init_private_attributes(self: BaseModel, context: Any, /) -> None: 338 """This function is meant to behave like a BaseModel method to initialise private attributes. 339 340 It takes context as an argument since that's what pydantic-core passes when calling it. 341 342 Args: 343 self: The BaseModel instance. 344 context: The context. 345 """ 346 if getattr(self, '__pydantic_private__', None) is None: 347 pydantic_private = {} 348 for name, private_attr in self.__private_attributes__.items(): 349 default = private_attr.get_default() 350 if default is not PydanticUndefined: 351 pydantic_private[name] = default 352 object_setattr(self, '__pydantic_private__', pydantic_private)
This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that's what pydantic-core passes when calling it.
Arguments:
- self: The BaseModel instance.
- context: The context.
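A generic pydantic sketch (class and attribute names are illustrative) of the private attributes this helper initializes: defaults declared via PrivateAttr end up in __pydantic_private__ and are excluded from validation and dumping.

from pydantic import BaseModel, PrivateAttr

class Cache(BaseModel):
    value: int
    _n_lookups: int = PrivateAttr(default=0)  # stored in __pydantic_private__

c = Cache(value=1)
print(c._n_lookups)    # 0, the PrivateAttr default
print(c.model_dump())  # {'value': 1}; private attributes are not dumped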
3439def generate_covers( 3440 inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]], 3441 outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]], 3442) -> List[Path]: 3443 def squeeze( 3444 data: NDArray[Any], axes: Sequence[AnyAxis] 3445 ) -> Tuple[NDArray[Any], List[AnyAxis]]: 3446 """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining""" 3447 if data.ndim != len(axes): 3448 raise ValueError( 3449 f"tensor shape {data.shape} does not match described axes" 3450 + f" {[a.id for a in axes]}" 3451 ) 3452 3453 axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1] 3454 return data.squeeze(), axes 3455 3456 def normalize( 3457 data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7 3458 ) -> NDArray[np.float32]: 3459 data = data.astype("float32") 3460 data -= data.min(axis=axis, keepdims=True) 3461 data /= data.max(axis=axis, keepdims=True) + eps 3462 return data 3463 3464 def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]): 3465 original_shape = data.shape 3466 data, axes = squeeze(data, axes) 3467 3468 # take slice fom any batch or index axis if needed 3469 # and convert the first channel axis and take a slice from any additional channel axes 3470 slices: Tuple[slice, ...] = () 3471 ndim = data.ndim 3472 ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2 3473 has_c_axis = False 3474 for i, a in enumerate(axes): 3475 s = data.shape[i] 3476 assert s > 1 3477 if ( 3478 isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis)) 3479 and ndim > ndim_need 3480 ): 3481 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3482 ndim -= 1 3483 elif isinstance(a, ChannelAxis): 3484 if has_c_axis: 3485 # second channel axis 3486 data = data[slices + (slice(0, 1),)] 3487 ndim -= 1 3488 else: 3489 has_c_axis = True 3490 if s == 2: 3491 # visualize two channels with cyan and magenta 3492 data = np.concatenate( 3493 [ 3494 data[slices + (slice(1, 2),)], 3495 data[slices + (slice(0, 1),)], 3496 ( 3497 data[slices + (slice(0, 1),)] 3498 + data[slices + (slice(1, 2),)] 3499 ) 3500 / 2, # TODO: take maximum instead? 
3501 ], 3502 axis=i, 3503 ) 3504 elif data.shape[i] == 3: 3505 pass # visualize 3 channels as RGB 3506 else: 3507 # visualize first 3 channels as RGB 3508 data = data[slices + (slice(3),)] 3509 3510 assert data.shape[i] == 3 3511 3512 slices += (slice(None),) 3513 3514 data, axes = squeeze(data, axes) 3515 assert len(axes) == ndim 3516 # take slice from z axis if needed 3517 slices = () 3518 if ndim > ndim_need: 3519 for i, a in enumerate(axes): 3520 s = data.shape[i] 3521 if a.id == AxisId("z"): 3522 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3523 data, axes = squeeze(data, axes) 3524 ndim -= 1 3525 break 3526 3527 slices += (slice(None),) 3528 3529 # take slice from any space or time axis 3530 slices = () 3531 3532 for i, a in enumerate(axes): 3533 if ndim <= ndim_need: 3534 break 3535 3536 s = data.shape[i] 3537 assert s > 1 3538 if isinstance( 3539 a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis) 3540 ): 3541 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3542 ndim -= 1 3543 3544 slices += (slice(None),) 3545 3546 del slices 3547 data, axes = squeeze(data, axes) 3548 assert len(axes) == ndim 3549 3550 if (has_c_axis and ndim != 3) or ndim != 2: 3551 raise ValueError( 3552 f"Failed to construct cover image from shape {original_shape}" 3553 ) 3554 3555 if not has_c_axis: 3556 assert ndim == 2 3557 data = np.repeat(data[:, :, None], 3, axis=2) 3558 axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB")))) 3559 ndim += 1 3560 3561 assert ndim == 3 3562 3563 # transpose axis order such that longest axis comes first... 3564 axis_order: List[int] = list(np.argsort(list(data.shape))) 3565 axis_order.reverse() 3566 # ... and channel axis is last 3567 c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0] 3568 axis_order.append(axis_order.pop(c)) 3569 axes = [axes[ao] for ao in axis_order] 3570 data = data.transpose(axis_order) 3571 3572 # h, w = data.shape[:2] 3573 # if h / w in (1.0 or 2.0): 3574 # pass 3575 # elif h / w < 2: 3576 # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images 3577 3578 norm_along = ( 3579 tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None 3580 ) 3581 # normalize the data and map to 8 bit 3582 data = normalize(data, norm_along) 3583 data = (data * 255).astype("uint8") 3584 3585 return data 3586 3587 def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]): 3588 assert im0.dtype == im1.dtype == np.uint8 3589 assert im0.shape == im1.shape 3590 assert im0.ndim == 3 3591 N, M, C = im0.shape 3592 assert C == 3 3593 out = np.ones((N, M, C), dtype="uint8") 3594 for c in range(C): 3595 outc = np.tril(im0[..., c]) 3596 mask = outc == 0 3597 outc[mask] = np.triu(im1[..., c])[mask] 3598 out[..., c] = outc 3599 3600 return out 3601 3602 ipt_descr, ipt = inputs[0] 3603 out_descr, out = outputs[0] 3604 3605 ipt_img = to_2d_image(ipt, ipt_descr.axes) 3606 out_img = to_2d_image(out, out_descr.axes) 3607 3608 cover_folder = Path(mkdtemp()) 3609 if ipt_img.shape == out_img.shape: 3610 covers = [cover_folder / "cover.png"] 3611 imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img)) 3612 else: 3613 covers = [cover_folder / "input.png", cover_folder / "output.png"] 3614 imwrite(covers[0], ipt_img) 3615 imwrite(covers[1], out_img) 3616 3617 return covers
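A standalone numpy sketch of the diagonal-split step performed by create_diagonal_split_image above, assuming two equally shaped uint8 RGB arrays (the random test images are placeholders): the lower triangle shows the first image, the upper triangle the second.

import numpy as np

def diagonal_split(im0: np.ndarray, im1: np.ndarray) -> np.ndarray:
    assert im0.dtype == im1.dtype == np.uint8
    assert im0.shape == im1.shape and im0.ndim == 3 and im0.shape[2] == 3
    out = np.empty_like(im0)
    for c in range(3):
        lower = np.tril(im0[..., c])              # keep im0 on and below the diagonal
        mask = lower == 0                         # zeroed upper triangle (and true-zero pixels)
        lower[mask] = np.triu(im1[..., c])[mask]  # fill those positions from im1
        out[..., c] = lower
    return out

rng = np.random.default_rng(0)
ipt_img = rng.integers(0, 256, (64, 64, 3), dtype=np.uint8)
out_img = rng.integers(0, 256, (64, 64, 3), dtype=np.uint8)
cover = diagonal_split(ipt_img, out_img)
print(cover.shape)  # (64, 64, 3)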