bioimageio.spec.model.v0_5
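Quick orientation before the source listing: this module implements the bioimage.io model description format 0.5, with typed axis, size and data descriptions for input/output tensors and their pre-/postprocessing. The sketch below is illustrative only — the tensor id, axis ids, channel name and file path are made-up placeholders, and validating the `test_tensor` entry requires the referenced file to exist (or IO checks to be disabled in the validation context). It shows how an input tensor with a parameterized spatial shape can be described:

```python
from bioimageio.spec.model.v0_5 import (
    AxisId,
    BatchAxis,
    ChannelAxis,
    FileDescr,
    Identifier,
    InputTensorDescr,
    ParameterizedSize,
    SpaceInputAxis,
    TensorId,
)

# valid 'x'/'y' sizes are 64, 80, 96, ...  (size = min + n*step)
size = ParameterizedSize(min=64, step=16)
assert size.get_size(n=2) == 96 and size.get_n(s=100) == 3

input_descr = InputTensorDescr(
    id=TensorId("raw"),  # placeholder id
    axes=[
        BatchAxis(),
        ChannelAxis(channel_names=[Identifier("channel0")]),  # placeholder name
        SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=64, step=16)),
        SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=64, step=16)),
    ],
    test_tensor=FileDescr(source="test_input.npy"),  # placeholder file
)
```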
1from __future__ import annotations 2 3import collections.abc 4import re 5import string 6import warnings 7from abc import ABC 8from copy import deepcopy 9from itertools import chain 10from math import ceil 11from pathlib import Path, PurePosixPath 12from tempfile import mkdtemp 13from typing import ( 14 TYPE_CHECKING, 15 Any, 16 ClassVar, 17 Dict, 18 Generic, 19 List, 20 Literal, 21 Mapping, 22 NamedTuple, 23 Optional, 24 Sequence, 25 Set, 26 Tuple, 27 Type, 28 TypeVar, 29 Union, 30 cast, 31) 32 33import numpy as np 34from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate 35from imageio.v3 import imread, imwrite # pyright: ignore[reportUnknownVariableType] 36from loguru import logger 37from numpy.typing import NDArray 38from pydantic import ( 39 AfterValidator, 40 Discriminator, 41 Field, 42 RootModel, 43 Tag, 44 ValidationInfo, 45 WrapSerializer, 46 field_validator, 47 model_validator, 48) 49from typing_extensions import Annotated, Self, assert_never, get_args 50 51from .._internal.common_nodes import ( 52 InvalidDescr, 53 Node, 54 NodeWithExplicitlySetFields, 55) 56from .._internal.constants import DTYPE_LIMITS 57from .._internal.field_warning import issue_warning, warn 58from .._internal.io import BioimageioYamlContent as BioimageioYamlContent 59from .._internal.io import FileDescr as FileDescr 60from .._internal.io import WithSuffix, YamlValue, download 61from .._internal.io_basics import AbsoluteFilePath as AbsoluteFilePath 62from .._internal.io_basics import Sha256 as Sha256 63from .._internal.io_utils import load_array 64from .._internal.node_converter import Converter 65from .._internal.types import ( 66 AbsoluteTolerance, 67 ImportantFileSource, 68 LowerCaseIdentifier, 69 LowerCaseIdentifierAnno, 70 MismatchedElementsPerMillion, 71 RelativeTolerance, 72 SiUnit, 73) 74from .._internal.types import Datetime as Datetime 75from .._internal.types import Identifier as Identifier 76from .._internal.types import NotEmpty as NotEmpty 77from .._internal.url import HttpUrl as HttpUrl 78from .._internal.validation_context import get_validation_context 79from .._internal.validator_annotations import RestrictCharacters 80from .._internal.version_type import Version as Version 81from .._internal.warning_levels import INFO 82from ..dataset.v0_2 import DatasetDescr as DatasetDescr02 83from ..dataset.v0_2 import LinkedDataset as LinkedDataset02 84from ..dataset.v0_3 import DatasetDescr as DatasetDescr 85from ..dataset.v0_3 import DatasetId as DatasetId 86from ..dataset.v0_3 import LinkedDataset as LinkedDataset 87from ..dataset.v0_3 import Uploader as Uploader 88from ..generic.v0_3 import ( 89 VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS, 90) 91from ..generic.v0_3 import Author as Author 92from ..generic.v0_3 import BadgeDescr as BadgeDescr 93from ..generic.v0_3 import CiteEntry as CiteEntry 94from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId 95from ..generic.v0_3 import ( 96 DocumentationSource, 97 GenericModelDescrBase, 98 LinkedResourceBase, 99 _author_conv, # pyright: ignore[reportPrivateUsage] 100 _maintainer_conv, # pyright: ignore[reportPrivateUsage] 101) 102from ..generic.v0_3 import Doi as Doi 103from ..generic.v0_3 import LicenseId as LicenseId 104from ..generic.v0_3 import LinkedResource as LinkedResource 105from ..generic.v0_3 import Maintainer as Maintainer 106from ..generic.v0_3 import OrcidId as OrcidId 107from ..generic.v0_3 import RelativeFilePath as RelativeFilePath 108from ..generic.v0_3 import ResourceId as ResourceId 109from 
.v0_4 import Author as _Author_v0_4 110from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4 111from .v0_4 import CallableFromDepencency as CallableFromDepencency 112from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4 113from .v0_4 import CallableFromFile as _CallableFromFile_v0_4 114from .v0_4 import ClipDescr as _ClipDescr_v0_4 115from .v0_4 import ClipKwargs as ClipKwargs 116from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4 117from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4 118from .v0_4 import KnownRunMode as KnownRunMode 119from .v0_4 import ModelDescr as _ModelDescr_v0_4 120from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4 121from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4 122from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4 123from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4 124from .v0_4 import ProcessingKwargs as ProcessingKwargs 125from .v0_4 import RunMode as RunMode 126from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4 127from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4 128from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4 129from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4 130from .v0_4 import TensorName as _TensorName_v0_4 131from .v0_4 import WeightsFormat as WeightsFormat 132from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4 133from .v0_4 import package_weights 134 135SpaceUnit = Literal[ 136 "attometer", 137 "angstrom", 138 "centimeter", 139 "decimeter", 140 "exameter", 141 "femtometer", 142 "foot", 143 "gigameter", 144 "hectometer", 145 "inch", 146 "kilometer", 147 "megameter", 148 "meter", 149 "micrometer", 150 "mile", 151 "millimeter", 152 "nanometer", 153 "parsec", 154 "petameter", 155 "picometer", 156 "terameter", 157 "yard", 158 "yoctometer", 159 "yottameter", 160 "zeptometer", 161 "zettameter", 162] 163"""Space unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)""" 164 165TimeUnit = Literal[ 166 "attosecond", 167 "centisecond", 168 "day", 169 "decisecond", 170 "exasecond", 171 "femtosecond", 172 "gigasecond", 173 "hectosecond", 174 "hour", 175 "kilosecond", 176 "megasecond", 177 "microsecond", 178 "millisecond", 179 "minute", 180 "nanosecond", 181 "petasecond", 182 "picosecond", 183 "second", 184 "terasecond", 185 "yoctosecond", 186 "yottasecond", 187 "zeptosecond", 188 "zettasecond", 189] 190"""Time unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)""" 191 192AxisType = Literal["batch", "channel", "index", "time", "space"] 193 194_AXIS_TYPE_MAP: Mapping[str, AxisType] = { 195 "b": "batch", 196 "t": "time", 197 "i": "index", 198 "c": "channel", 199 "x": "space", 200 "y": "space", 201 "z": "space", 202} 203 204_AXIS_ID_MAP = { 205 "b": "batch", 206 "t": "time", 207 "i": "index", 208 "c": "channel", 209} 210 211 212class TensorId(LowerCaseIdentifier): 213 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 214 Annotated[LowerCaseIdentifierAnno, MaxLen(32)] 215 ] 216 217 218def _normalize_axis_id(a: str): 219 a = str(a) 220 normalized = _AXIS_ID_MAP.get(a, a) 221 if a != normalized: 222 logger.opt(depth=3).warning( 223 "Normalized axis id from '{}' to '{}'.", a, normalized 224 ) 225 return normalized 226 227 228class AxisId(LowerCaseIdentifier): 229 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 230 Annotated[ 231 LowerCaseIdentifierAnno, 232 MaxLen(16), 233 
AfterValidator(_normalize_axis_id), 234 ] 235 ] 236 237 238 def _is_batch(a: str) -> bool: 239 return str(a) == "batch" 240 241 242 def _is_not_batch(a: str) -> bool: 243 return not _is_batch(a) 244 245 246 NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)] 247 248 PostprocessingId = Literal[ 249 "binarize", 250 "clip", 251 "ensure_dtype", 252 "fixed_zero_mean_unit_variance", 253 "scale_linear", 254 "scale_mean_variance", 255 "scale_range", 256 "sigmoid", 257 "zero_mean_unit_variance", 258 ] 259 PreprocessingId = Literal[ 260 "binarize", 261 "clip", 262 "ensure_dtype", 263 "scale_linear", 264 "sigmoid", 265 "zero_mean_unit_variance", 266 "scale_range", 267 ] 268 269 270 SAME_AS_TYPE = "<same as type>" 271 272 273 ParameterizedSize_N = int 274 """ 275 Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`. 276 """ 277 278 279 class ParameterizedSize(Node): 280 """Describes a range of valid tensor axis sizes as `size = min + n*step`. 281 282 - **min** and **step** are given by the model description. 283 - All blocksize parameters n = 0,1,2,... yield a valid `size`. 284 - A greater blocksize parameter n = 0,1,2,... results in a greater **size**. 285 This allows adjusting the axis size more generically. 286 """ 287 288 N: ClassVar[Type[int]] = ParameterizedSize_N 289 """Non-negative integer to parameterize this axis""" 290 291 min: Annotated[int, Gt(0)] 292 step: Annotated[int, Gt(0)] 293 294 def validate_size(self, size: int) -> int: 295 if size < self.min: 296 raise ValueError(f"size {size} < {self.min}") 297 if (size - self.min) % self.step != 0: 298 raise ValueError( 299 f"axis of size {size} is not parameterized by `min + n*step` =" 300 + f" `{self.min} + n*{self.step}`" 301 ) 302 303 return size 304 305 def get_size(self, n: ParameterizedSize_N) -> int: 306 return self.min + self.step * n 307 308 def get_n(self, s: int) -> ParameterizedSize_N: 309 """return the smallest n parameterizing a size greater than or equal to `s`""" 310 return ceil((s - self.min) / self.step) 311 312 313 class DataDependentSize(Node): 314 min: Annotated[int, Gt(0)] = 1 315 max: Annotated[Optional[int], Gt(1)] = None 316 317 @model_validator(mode="after") 318 def _validate_max_gt_min(self): 319 if self.max is not None and self.min >= self.max: 320 raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}") 321 322 return self 323 324 def validate_size(self, size: int) -> int: 325 if size < self.min: 326 raise ValueError(f"size {size} < {self.min}") 327 328 if self.max is not None and size > self.max: 329 raise ValueError(f"size {size} > {self.max}") 330 331 return size 332 333 334 class SizeReference(Node): 335 """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis. 336 337 `axis.size = reference.size * reference.scale / axis.scale + offset` 338 339 Note: 340 1. The axis and the referenced axis need to have the same unit (or no unit). 341 2. Batch axes may not be referenced. 342 3. Fractions are rounded down. 343 4. If the reference axis is `concatenable` the referencing axis is assumed to be 344 `concatenable` as well with the same block order. 345 346 Example: 347 An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196mm². 348 Let's assume that we want to express the image height h in relation to its width w 349 instead of only accepting input images of exactly 100*49 pixels 350 (for example to express a range of valid image shapes by parameterizing w, see `ParameterizedSize`). 
351 352 >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2) 353 >>> h = SpaceInputAxis( 354 ... id=AxisId("h"), 355 ... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1), 356 ... unit="millimeter", 357 ... scale=4, 358 ... ) 359 >>> print(h.size.get_size(h, w)) 360 49 361 362 ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49 363 """ 364 365 tensor_id: TensorId 366 """tensor id of the reference axis""" 367 368 axis_id: AxisId 369 """axis id of the reference axis""" 370 371 offset: int = 0 372 373 def get_size( 374 self, 375 axis: Union[ 376 ChannelAxis, 377 IndexInputAxis, 378 IndexOutputAxis, 379 TimeInputAxis, 380 SpaceInputAxis, 381 TimeOutputAxis, 382 TimeOutputAxisWithHalo, 383 SpaceOutputAxis, 384 SpaceOutputAxisWithHalo, 385 ], 386 ref_axis: Union[ 387 ChannelAxis, 388 IndexInputAxis, 389 IndexOutputAxis, 390 TimeInputAxis, 391 SpaceInputAxis, 392 TimeOutputAxis, 393 TimeOutputAxisWithHalo, 394 SpaceOutputAxis, 395 SpaceOutputAxisWithHalo, 396 ], 397 n: ParameterizedSize_N = 0, 398 ref_size: Optional[int] = None, 399 ): 400 """Compute the concrete size for a given axis and its reference axis. 401 402 Args: 403 axis: The axis this `SizeReference` is the size of. 404 ref_axis: The reference axis to compute the size from. 405 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`) 406 and no fixed **ref_size** is given, 407 **n** is used to compute the size of the parameterized **ref_axis**. 408 ref_size: Overwrite the reference size instead of deriving it from 409 **ref_axis** 410 (**ref_axis.scale** is still used; any given **n** is ignored). 411 """ 412 assert ( 413 axis.size == self 414 ), "Given `axis.size` is not defined by this `SizeReference`" 415 416 assert ( 417 ref_axis.id == self.axis_id 418 ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}." 419 420 assert axis.unit == ref_axis.unit, ( 421 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`," 422 f" but {axis.unit}!={ref_axis.unit}" 423 ) 424 if ref_size is None: 425 if isinstance(ref_axis.size, (int, float)): 426 ref_size = ref_axis.size 427 elif isinstance(ref_axis.size, ParameterizedSize): 428 ref_size = ref_axis.size.get_size(n) 429 elif isinstance(ref_axis.size, DataDependentSize): 430 raise ValueError( 431 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`." 432 ) 433 elif isinstance(ref_axis.size, SizeReference): 434 raise ValueError( 435 "Reference axis referenced in `SizeReference` may not be sized by a" 436 + " `SizeReference` itself." 437 ) 438 else: 439 assert_never(ref_axis.size) 440 441 return int(ref_size * ref_axis.scale / axis.scale + self.offset) 442 443 @staticmethod 444 def _get_unit( 445 axis: Union[ 446 ChannelAxis, 447 IndexInputAxis, 448 IndexOutputAxis, 449 TimeInputAxis, 450 SpaceInputAxis, 451 TimeOutputAxis, 452 TimeOutputAxisWithHalo, 453 SpaceOutputAxis, 454 SpaceOutputAxisWithHalo, 455 ], 456 ): 457 return axis.unit 458 459 460class AxisBase(NodeWithExplicitlySetFields): 461 id: AxisId 462 """An axis id unique across all axes of one tensor.""" 463 464 description: Annotated[str, MaxLen(128)] = "" 465 466 467class WithHalo(Node): 468 halo: Annotated[int, Ge(1)] 469 """The halo should be cropped from the output tensor to avoid boundary effects. 470 It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`. 
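For example, an output axis of size 128 with `halo: 16` keeps 128 - 2*16 = 96 entries after cropping.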
471 To document a halo that is already cropped by the model use `size.offset` instead.""" 472 473 size: Annotated[ 474 SizeReference, 475 Field( 476 examples=[ 477 10, 478 SizeReference( 479 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 480 ).model_dump(mode="json"), 481 ] 482 ), 483 ] 484 """reference to another axis with an optional offset (see `SizeReference`)""" 485 486 487BATCH_AXIS_ID = AxisId("batch") 488 489 490class BatchAxis(AxisBase): 491 implemented_type: ClassVar[Literal["batch"]] = "batch" 492 if TYPE_CHECKING: 493 type: Literal["batch"] = "batch" 494 else: 495 type: Literal["batch"] 496 497 id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID 498 size: Optional[Literal[1]] = None 499 """The batch size may be fixed to 1, 500 otherwise (the default) it may be chosen arbitrarily depending on available memory""" 501 502 @property 503 def scale(self): 504 return 1.0 505 506 @property 507 def concatenable(self): 508 return True 509 510 @property 511 def unit(self): 512 return None 513 514 515class ChannelAxis(AxisBase): 516 implemented_type: ClassVar[Literal["channel"]] = "channel" 517 if TYPE_CHECKING: 518 type: Literal["channel"] = "channel" 519 else: 520 type: Literal["channel"] 521 522 id: NonBatchAxisId = AxisId("channel") 523 channel_names: NotEmpty[List[Identifier]] 524 525 @property 526 def size(self) -> int: 527 return len(self.channel_names) 528 529 @property 530 def concatenable(self): 531 return False 532 533 @property 534 def scale(self) -> float: 535 return 1.0 536 537 @property 538 def unit(self): 539 return None 540 541 542class IndexAxisBase(AxisBase): 543 implemented_type: ClassVar[Literal["index"]] = "index" 544 if TYPE_CHECKING: 545 type: Literal["index"] = "index" 546 else: 547 type: Literal["index"] 548 549 id: NonBatchAxisId = AxisId("index") 550 551 @property 552 def scale(self) -> float: 553 return 1.0 554 555 @property 556 def unit(self): 557 return None 558 559 560class _WithInputAxisSize(Node): 561 size: Annotated[ 562 Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference], 563 Field( 564 examples=[ 565 10, 566 ParameterizedSize(min=32, step=16).model_dump(mode="json"), 567 SizeReference( 568 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 569 ).model_dump(mode="json"), 570 ] 571 ), 572 ] 573 """The size/length of this axis can be specified as 574 - fixed integer 575 - parameterized series of valid sizes (`ParameterizedSize`) 576 - reference to another axis with an optional offset (`SizeReference`) 577 """ 578 579 580class IndexInputAxis(IndexAxisBase, _WithInputAxisSize): 581 concatenable: bool = False 582 """If a model has a `concatenable` input axis, it can be processed blockwise, 583 splitting a longer sample axis into blocks matching its input tensor description. 584 Output axes are concatenable if they have a `SizeReference` to a concatenable 585 input axis. 
586 """ 587 588 589class IndexOutputAxis(IndexAxisBase): 590 size: Annotated[ 591 Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize], 592 Field( 593 examples=[ 594 10, 595 SizeReference( 596 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 597 ).model_dump(mode="json"), 598 ] 599 ), 600 ] 601 """The size/length of this axis can be specified as 602 - fixed integer 603 - reference to another axis with an optional offset (`SizeReference`) 604 - data dependent size using `DataDependentSize` (size is only known after model inference) 605 """ 606 607 608class TimeAxisBase(AxisBase): 609 implemented_type: ClassVar[Literal["time"]] = "time" 610 if TYPE_CHECKING: 611 type: Literal["time"] = "time" 612 else: 613 type: Literal["time"] 614 615 id: NonBatchAxisId = AxisId("time") 616 unit: Optional[TimeUnit] = None 617 scale: Annotated[float, Gt(0)] = 1.0 618 619 620class TimeInputAxis(TimeAxisBase, _WithInputAxisSize): 621 concatenable: bool = False 622 """If a model has a `concatenable` input axis, it can be processed blockwise, 623 splitting a longer sample axis into blocks matching its input tensor description. 624 Output axes are concatenable if they have a `SizeReference` to a concatenable 625 input axis. 626 """ 627 628 629class SpaceAxisBase(AxisBase): 630 implemented_type: ClassVar[Literal["space"]] = "space" 631 if TYPE_CHECKING: 632 type: Literal["space"] = "space" 633 else: 634 type: Literal["space"] 635 636 id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x") 637 unit: Optional[SpaceUnit] = None 638 scale: Annotated[float, Gt(0)] = 1.0 639 640 641class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize): 642 concatenable: bool = False 643 """If a model has a `concatenable` input axis, it can be processed blockwise, 644 splitting a longer sample axis into blocks matching its input tensor description. 645 Output axes are concatenable if they have a `SizeReference` to a concatenable 646 input axis. 
647 """ 648 649 650INPUT_AXIS_TYPES = ( 651 BatchAxis, 652 ChannelAxis, 653 IndexInputAxis, 654 TimeInputAxis, 655 SpaceInputAxis, 656) 657"""intended for isinstance comparisons in py<3.10""" 658 659_InputAxisUnion = Union[ 660 BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis 661] 662InputAxis = Annotated[_InputAxisUnion, Discriminator("type")] 663 664 665class _WithOutputAxisSize(Node): 666 size: Annotated[ 667 Union[Annotated[int, Gt(0)], SizeReference], 668 Field( 669 examples=[ 670 10, 671 SizeReference( 672 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 673 ).model_dump(mode="json"), 674 ] 675 ), 676 ] 677 """The size/length of this axis can be specified as 678 - fixed integer 679 - reference to another axis with an optional offset (see `SizeReference`) 680 """ 681 682 683class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize): 684 pass 685 686 687class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo): 688 pass 689 690 691def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]: 692 if isinstance(v, dict): 693 return "with_halo" if "halo" in v else "wo_halo" 694 else: 695 return "with_halo" if hasattr(v, "halo") else "wo_halo" 696 697 698_TimeOutputAxisUnion = Annotated[ 699 Union[ 700 Annotated[TimeOutputAxis, Tag("wo_halo")], 701 Annotated[TimeOutputAxisWithHalo, Tag("with_halo")], 702 ], 703 Discriminator(_get_halo_axis_discriminator_value), 704] 705 706 707class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize): 708 pass 709 710 711class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo): 712 pass 713 714 715_SpaceOutputAxisUnion = Annotated[ 716 Union[ 717 Annotated[SpaceOutputAxis, Tag("wo_halo")], 718 Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")], 719 ], 720 Discriminator(_get_halo_axis_discriminator_value), 721] 722 723 724_OutputAxisUnion = Union[ 725 BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion 726] 727OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")] 728 729OUTPUT_AXIS_TYPES = ( 730 BatchAxis, 731 ChannelAxis, 732 IndexOutputAxis, 733 TimeOutputAxis, 734 TimeOutputAxisWithHalo, 735 SpaceOutputAxis, 736 SpaceOutputAxisWithHalo, 737) 738"""intended for isinstance comparisons in py<3.10""" 739 740 741AnyAxis = Union[InputAxis, OutputAxis] 742 743ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES 744"""intended for isinstance comparisons in py<3.10""" 745 746TVs = Union[ 747 NotEmpty[List[int]], 748 NotEmpty[List[float]], 749 NotEmpty[List[bool]], 750 NotEmpty[List[str]], 751] 752 753 754NominalOrOrdinalDType = Literal[ 755 "float32", 756 "float64", 757 "uint8", 758 "int8", 759 "uint16", 760 "int16", 761 "uint32", 762 "int32", 763 "uint64", 764 "int64", 765 "bool", 766] 767 768 769class NominalOrOrdinalDataDescr(Node): 770 values: TVs 771 """A fixed set of nominal or an ascending sequence of ordinal values. 772 In this case `data.type` is required to be an unsigend integer type, e.g. 'uint8'. 773 String `values` are interpreted as labels for tensor values 0, ..., N. 774 Note: as YAML 1.2 does not natively support a "set" datatype, 775 nominal values should be given as a sequence (aka list/array) as well. 
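For illustration (a sketch; the class labels are placeholders), three nominal classes stored as uint8 values 0, 1, 2 could be described as:

>>> data_descr = NominalOrOrdinalDataDescr(
...     values=["background", "cell", "border"], type="uint8"
... )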
776 """ 777 778 type: Annotated[ 779 NominalOrOrdinalDType, 780 Field( 781 examples=[ 782 "float32", 783 "uint8", 784 "uint16", 785 "int64", 786 "bool", 787 ], 788 ), 789 ] = "uint8" 790 791 @model_validator(mode="after") 792 def _validate_values_match_type( 793 self, 794 ) -> Self: 795 incompatible: List[Any] = [] 796 for v in self.values: 797 if self.type == "bool": 798 if not isinstance(v, bool): 799 incompatible.append(v) 800 elif self.type in DTYPE_LIMITS: 801 if ( 802 isinstance(v, (int, float)) 803 and ( 804 v < DTYPE_LIMITS[self.type].min 805 or v > DTYPE_LIMITS[self.type].max 806 ) 807 or (isinstance(v, str) and "uint" not in self.type) 808 or (isinstance(v, float) and "int" in self.type) 809 ): 810 incompatible.append(v) 811 else: 812 incompatible.append(v) 813 814 if len(incompatible) == 5: 815 incompatible.append("...") 816 break 817 818 if incompatible: 819 raise ValueError( 820 f"data type '{self.type}' incompatible with values {incompatible}" 821 ) 822 823 return self 824 825 unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None 826 827 @property 828 def range(self): 829 if isinstance(self.values[0], str): 830 return 0, len(self.values) - 1 831 else: 832 return min(self.values), max(self.values) 833 834 835IntervalOrRatioDType = Literal[ 836 "float32", 837 "float64", 838 "uint8", 839 "int8", 840 "uint16", 841 "int16", 842 "uint32", 843 "int32", 844 "uint64", 845 "int64", 846] 847 848 849class IntervalOrRatioDataDescr(Node): 850 type: Annotated[ # todo: rename to dtype 851 IntervalOrRatioDType, 852 Field( 853 examples=["float32", "float64", "uint8", "uint16"], 854 ), 855 ] = "float32" 856 range: Tuple[Optional[float], Optional[float]] = ( 857 None, 858 None, 859 ) 860 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor. 861 `None` corresponds to min/max of what can be expressed by **type**.""" 862 unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit" 863 scale: float = 1.0 864 """Scale for data on an interval (or ratio) scale.""" 865 offset: Optional[float] = None 866 """Offset for data on a ratio scale.""" 867 868 869TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr] 870 871 872class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC): 873 """processing base class""" 874 875 876class BinarizeKwargs(ProcessingKwargs): 877 """key word arguments for `BinarizeDescr`""" 878 879 threshold: float 880 """The fixed threshold""" 881 882 883class BinarizeAlongAxisKwargs(ProcessingKwargs): 884 """key word arguments for `BinarizeDescr`""" 885 886 threshold: NotEmpty[List[float]] 887 """The fixed threshold values along `axis`""" 888 889 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 890 """The `threshold` axis""" 891 892 893class BinarizeDescr(ProcessingDescrBase): 894 """Binarize the tensor with a fixed threshold. 895 896 Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold` 897 will be set to one, values below the threshold to zero. 898 899 Examples: 900 - in YAML 901 ```yaml 902 postprocessing: 903 - id: binarize 904 kwargs: 905 axis: 'channel' 906 threshold: [0.25, 0.5, 0.75] 907 ``` 908 - in Python: 909 >>> postprocessing = [BinarizeDescr( 910 ... kwargs=BinarizeAlongAxisKwargs( 911 ... axis=AxisId('channel'), 912 ... threshold=[0.25, 0.5, 0.75], 913 ... ) 914 ... 
)] 915 """ 916 917 implemented_id: ClassVar[Literal["binarize"]] = "binarize" 918 if TYPE_CHECKING: 919 id: Literal["binarize"] = "binarize" 920 else: 921 id: Literal["binarize"] 922 kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs] 923 924 925class ClipDescr(ProcessingDescrBase): 926 """Set tensor values below min to min and above max to max. 927 928 See `ScaleRangeDescr` for examples. 929 """ 930 931 implemented_id: ClassVar[Literal["clip"]] = "clip" 932 if TYPE_CHECKING: 933 id: Literal["clip"] = "clip" 934 else: 935 id: Literal["clip"] 936 937 kwargs: ClipKwargs 938 939 940class EnsureDtypeKwargs(ProcessingKwargs): 941 """key word arguments for `EnsureDtypeDescr`""" 942 943 dtype: Literal[ 944 "float32", 945 "float64", 946 "uint8", 947 "int8", 948 "uint16", 949 "int16", 950 "uint32", 951 "int32", 952 "uint64", 953 "int64", 954 "bool", 955 ] 956 957 958class EnsureDtypeDescr(ProcessingDescrBase): 959 """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching). 960 961 This can for example be used to ensure the inner neural network model gets a 962 different input tensor data type than the fully described bioimage.io model does. 963 964 Examples: 965 The described bioimage.io model (incl. preprocessing) accepts any 966 float32-compatible tensor, normalizes it with percentiles and clipping and then 967 casts it to uint8, which is what the neural network in this example expects. 968 - in YAML 969 ```yaml 970 inputs: 971 - data: 972 type: float32 # described bioimage.io model is compatible with any float32 input tensor 973 preprocessing: 974 - id: scale_range 975 kwargs: 976 axes: ['y', 'x'] 977 max_percentile: 99.8 978 min_percentile: 5.0 979 - id: clip 980 kwargs: 981 min: 0.0 982 max: 1.0 983 - id: ensure_dtype 984 kwargs: 985 dtype: uint8 986 ``` 987 - in Python: 988 >>> preprocessing = [ 989 ... ScaleRangeDescr( 990 ... kwargs=ScaleRangeKwargs( 991 ... axes= (AxisId('y'), AxisId('x')), 992 ... max_percentile= 99.8, 993 ... min_percentile= 5.0, 994 ... ) 995 ... ), 996 ... ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)), 997 ... EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")), 998 ... ] 999 """ 1000 1001 implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype" 1002 if TYPE_CHECKING: 1003 id: Literal["ensure_dtype"] = "ensure_dtype" 1004 else: 1005 id: Literal["ensure_dtype"] 1006 1007 kwargs: EnsureDtypeKwargs 1008 1009 1010class ScaleLinearKwargs(ProcessingKwargs): 1011 """Key word arguments for `ScaleLinearDescr`""" 1012 1013 gain: float = 1.0 1014 """multiplicative factor""" 1015 1016 offset: float = 0.0 1017 """additive term""" 1018 1019 @model_validator(mode="after") 1020 def _validate(self) -> Self: 1021 if self.gain == 1.0 and self.offset == 0.0: 1022 raise ValueError( 1023 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 1024 + " != 0.0." 
1025 ) 1026 1027 return self 1028 1029 1030 class ScaleLinearAlongAxisKwargs(ProcessingKwargs): 1031 """Key word arguments for `ScaleLinearDescr`""" 1032 1033 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 1034 """The axis of gain and offset values.""" 1035 1036 gain: Union[float, NotEmpty[List[float]]] = 1.0 1037 """multiplicative factor""" 1038 1039 offset: Union[float, NotEmpty[List[float]]] = 0.0 1040 """additive term""" 1041 1042 @model_validator(mode="after") 1043 def _validate(self) -> Self: 1044 1045 if isinstance(self.gain, list): 1046 if isinstance(self.offset, list): 1047 if len(self.gain) != len(self.offset): 1048 raise ValueError( 1049 f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match." 1050 ) 1051 else: 1052 self.offset = [float(self.offset)] * len(self.gain) 1053 elif isinstance(self.offset, list): 1054 self.gain = [float(self.gain)] * len(self.offset) 1055 else: 1056 raise ValueError( 1057 "Do not specify an `axis` for scalar gain and offset values." 1058 ) 1059 1060 if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset): 1061 raise ValueError( 1062 "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or `offset`" 1063 + " != 0.0." 1064 ) 1065 1066 return self 1067 1068 1069 class ScaleLinearDescr(ProcessingDescrBase): 1070 """Fixed linear scaling. 1071 1072 Examples: 1073 1. Scale with scalar gain and offset 1074 - in YAML 1075 ```yaml 1076 preprocessing: 1077 - id: scale_linear 1078 kwargs: 1079 gain: 2.0 1080 offset: 3.0 1081 ``` 1082 - in Python: 1083 >>> preprocessing = [ 1084 ... ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0)) 1085 ... ] 1086 1087 2. Independent scaling along an axis 1088 - in YAML 1089 ```yaml 1090 preprocessing: 1091 - id: scale_linear 1092 kwargs: 1093 axis: 'channel' 1094 gain: [1.0, 2.0, 3.0] 1095 ``` 1096 - in Python: 1097 >>> preprocessing = [ 1098 ... ScaleLinearDescr( 1099 ... kwargs=ScaleLinearAlongAxisKwargs( 1100 ... axis=AxisId("channel"), 1101 ... gain=[1.0, 2.0, 3.0], 1102 ... ) 1103 ... ) 1104 ... ] 1105 1106 """ 1107 1108 implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear" 1109 if TYPE_CHECKING: 1110 id: Literal["scale_linear"] = "scale_linear" 1111 else: 1112 id: Literal["scale_linear"] 1113 kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs] 1114 1115 1116 class SigmoidDescr(ProcessingDescrBase): 1117 """The logistic sigmoid function, a.k.a. expit function. 
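It maps every tensor element x to `1 / (1 + exp(-x))`, so all output values lie in the open interval (0, 1).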
1118 1119 Examples: 1120 - in YAML 1121 ```yaml 1122 postprocessing: 1123 - id: sigmoid 1124 ``` 1125 - in Python: 1126 >>> postprocessing = [SigmoidDescr()] 1127 """ 1128 1129 implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid" 1130 if TYPE_CHECKING: 1131 id: Literal["sigmoid"] = "sigmoid" 1132 else: 1133 id: Literal["sigmoid"] 1134 1135 @property 1136 def kwargs(self) -> ProcessingKwargs: 1137 """empty kwargs""" 1138 return ProcessingKwargs() 1139 1140 1141class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs): 1142 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 1143 1144 mean: float 1145 """The mean value to normalize with.""" 1146 1147 std: Annotated[float, Ge(1e-6)] 1148 """The standard deviation value to normalize with.""" 1149 1150 1151class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs): 1152 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 1153 1154 mean: NotEmpty[List[float]] 1155 """The mean value(s) to normalize with.""" 1156 1157 std: NotEmpty[List[Annotated[float, Ge(1e-6)]]] 1158 """The standard deviation value(s) to normalize with. 1159 Size must match `mean` values.""" 1160 1161 axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])] 1162 """The axis of the mean/std values to normalize each entry along that dimension 1163 separately.""" 1164 1165 @model_validator(mode="after") 1166 def _mean_and_std_match(self) -> Self: 1167 if len(self.mean) != len(self.std): 1168 raise ValueError( 1169 f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})" 1170 + " must match." 1171 ) 1172 1173 return self 1174 1175 1176class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase): 1177 """Subtract a given mean and divide by the standard deviation. 1178 1179 Normalize with fixed, precomputed values for 1180 `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std` 1181 Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given 1182 axes. 1183 1184 Examples: 1185 1. scalar value for whole tensor 1186 - in YAML 1187 ```yaml 1188 preprocessing: 1189 - id: fixed_zero_mean_unit_variance 1190 kwargs: 1191 mean: 103.5 1192 std: 13.7 1193 ``` 1194 - in Python 1195 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 1196 ... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7) 1197 ... )] 1198 1199 2. independently along an axis 1200 - in YAML 1201 ```yaml 1202 preprocessing: 1203 - id: fixed_zero_mean_unit_variance 1204 kwargs: 1205 axis: channel 1206 mean: [101.5, 102.5, 103.5] 1207 std: [11.7, 12.7, 13.7] 1208 ``` 1209 - in Python 1210 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 1211 ... kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 1212 ... axis=AxisId("channel"), 1213 ... mean=[101.5, 102.5, 103.5], 1214 ... std=[11.7, 12.7, 13.7], 1215 ... ) 1216 ... )] 1217 """ 1218 1219 implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = ( 1220 "fixed_zero_mean_unit_variance" 1221 ) 1222 if TYPE_CHECKING: 1223 id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance" 1224 else: 1225 id: Literal["fixed_zero_mean_unit_variance"] 1226 1227 kwargs: Union[ 1228 FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs 1229 ] 1230 1231 1232class ZeroMeanUnitVarianceKwargs(ProcessingKwargs): 1233 """key word arguments for `ZeroMeanUnitVarianceDescr`""" 1234 1235 axes: Annotated[ 1236 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1237 ] = None 1238 """The subset of axes to normalize jointly, i.e. 
axes to reduce to compute mean/std. 1239 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1240 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1241 To normalize each sample independently, leave out the 'batch' axis. 1242 Default: Scale all axes jointly.""" 1243 1244 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1245 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`.""" 1246 1247 1248 class ZeroMeanUnitVarianceDescr(ProcessingDescrBase): 1249 """Subtract mean and divide by the standard deviation. 1250 1251 Examples: 1252 Normalize a whole tensor to zero mean and unit variance 1253 - in YAML 1254 ```yaml 1255 preprocessing: 1256 - id: zero_mean_unit_variance 1257 ``` 1258 - in Python 1259 >>> preprocessing = [ZeroMeanUnitVarianceDescr()] 1260 """ 1261 1262 implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = ( 1263 "zero_mean_unit_variance" 1264 ) 1265 if TYPE_CHECKING: 1266 id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance" 1267 else: 1268 id: Literal["zero_mean_unit_variance"] 1269 1270 kwargs: ZeroMeanUnitVarianceKwargs = Field( 1271 default_factory=ZeroMeanUnitVarianceKwargs 1272 ) 1273 1274 1275 class ScaleRangeKwargs(ProcessingKwargs): 1276 """key word arguments for `ScaleRangeDescr` 1277 1278 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default) 1279 this processing step normalizes data to the [0, 1] interval. 1280 For other percentiles the normalized values will partially be outside the [0, 1] 1281 interval. Use `ScaleRangeDescr` followed by `ClipDescr` if you want to limit the 1282 normalized values to a range. 1283 """ 1284 1285 axes: Annotated[ 1286 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1287 ] = None 1288 """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value. 1289 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1290 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1291 To normalize samples independently, leave out the "batch" axis. 1292 Default: Scale all axes jointly.""" 1293 1294 min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0 1295 """The lower percentile used to determine the value to align with zero.""" 1296 1297 max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0 1298 """The upper percentile used to determine the value to align with one. 1299 Has to be bigger than `min_percentile`. 1300 The range is 1 to 100 instead of 0 to 100 to avoid mistakenly 1301 accepting percentiles specified in the range 0.0 to 1.0.""" 1302 1303 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1304 """Epsilon for numeric stability. 1305 `out = (tensor - v_lower) / (v_upper - v_lower + eps)`; 1306 with `v_lower,v_upper` values at the respective percentiles.""" 1307 1308 reference_tensor: Optional[TensorId] = None 1309 """Tensor ID to compute the percentiles from. Default: The tensor itself. 1310 For any tensor in `inputs` only input tensor references are allowed.""" 1311 1312 @field_validator("max_percentile", mode="after") 1313 @classmethod 1314 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float: 1315 if (min_p := info.data["min_percentile"]) >= value: 1316 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}") 1317 1318 return value 1319 1320 1321 class ScaleRangeDescr(ProcessingDescrBase): 1322 """Scale with percentiles. 
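The data is normalized as `out = (tensor - v_lower) / (v_upper - v_lower + eps)`, where `v_lower` and `v_upper` are the tensor values at `min_percentile` and `max_percentile`, respectively (see `ScaleRangeKwargs`).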
1323 1324 Examples: 1325 1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0 1326 - in YAML 1327 ```yaml 1328 preprocessing: 1329 - id: scale_range 1330 kwargs: 1331 axes: ['y', 'x'] 1332 max_percentile: 99.8 1333 min_percentile: 5.0 1334 ``` 1335 - in Python 1336 >>> preprocessing = [ 1337 ... ScaleRangeDescr( 1338 ... kwargs=ScaleRangeKwargs( 1339 ... axes= (AxisId('y'), AxisId('x')), 1340 ... max_percentile= 99.8, 1341 ... min_percentile= 5.0, 1342 ... ) 1343 ... ), 1344 ... ClipDescr( 1345 ... kwargs=ClipKwargs( 1346 ... min=0.0, 1347 ... max=1.0, 1348 ... ) 1349 ... ), 1350 ... ] 1351 1352 2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles. 1353 - in YAML 1354 ```yaml 1355 preprocessing: 1356 - id: scale_range 1357 kwargs: 1358 axes: ['y', 'x'] 1359 max_percentile: 99.8 1360 min_percentile: 5.0 1361 - id: scale_range 1362 - id: clip 1363 kwargs: 1364 min: 0.0 1365 max: 1.0 1366 ``` 1367 - in Python 1368 >>> preprocessing = [ScaleRangeDescr( 1369 ... kwargs=ScaleRangeKwargs( 1370 ... axes= (AxisId('y'), AxisId('x')), 1371 ... max_percentile= 99.8, 1372 ... min_percentile= 5.0, 1373 ... ) 1374 ... )] 1375 1376 """ 1377 1378 implemented_id: ClassVar[Literal["scale_range"]] = "scale_range" 1379 if TYPE_CHECKING: 1380 id: Literal["scale_range"] = "scale_range" 1381 else: 1382 id: Literal["scale_range"] 1383 kwargs: ScaleRangeKwargs 1384 1385 1386class ScaleMeanVarianceKwargs(ProcessingKwargs): 1387 """key word arguments for `ScaleMeanVarianceKwargs`""" 1388 1389 reference_tensor: TensorId 1390 """Name of tensor to match.""" 1391 1392 axes: Annotated[ 1393 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1394 ] = None 1395 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 1396 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1397 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1398 To normalize samples independently, leave out the 'batch' axis. 1399 Default: Scale all axes jointly.""" 1400 1401 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1402 """Epsilon for numeric stability: 1403 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`""" 1404 1405 1406class ScaleMeanVarianceDescr(ProcessingDescrBase): 1407 """Scale a tensor's data distribution to match another tensor's mean/std. 
1408 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.` 1409 """ 1410 1411 implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance" 1412 if TYPE_CHECKING: 1413 id: Literal["scale_mean_variance"] = "scale_mean_variance" 1414 else: 1415 id: Literal["scale_mean_variance"] 1416 kwargs: ScaleMeanVarianceKwargs 1417 1418 1419 PreprocessingDescr = Annotated[ 1420 Union[ 1421 BinarizeDescr, 1422 ClipDescr, 1423 EnsureDtypeDescr, 1424 ScaleLinearDescr, 1425 SigmoidDescr, 1426 FixedZeroMeanUnitVarianceDescr, 1427 ZeroMeanUnitVarianceDescr, 1428 ScaleRangeDescr, 1429 ], 1430 Discriminator("id"), 1431 ] 1432 PostprocessingDescr = Annotated[ 1433 Union[ 1434 BinarizeDescr, 1435 ClipDescr, 1436 EnsureDtypeDescr, 1437 ScaleLinearDescr, 1438 SigmoidDescr, 1439 FixedZeroMeanUnitVarianceDescr, 1440 ZeroMeanUnitVarianceDescr, 1441 ScaleRangeDescr, 1442 ScaleMeanVarianceDescr, 1443 ], 1444 Discriminator("id"), 1445 ] 1446 1447 IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis) 1448 1449 1450 class TensorDescrBase(Node, Generic[IO_AxisT]): 1451 id: TensorId 1452 """Tensor id. No duplicates are allowed.""" 1453 1454 description: Annotated[str, MaxLen(128)] = "" 1455 """free text description""" 1456 1457 axes: NotEmpty[Sequence[IO_AxisT]] 1458 """tensor axes""" 1459 1460 @property 1461 def shape(self): 1462 return tuple(a.size for a in self.axes) 1463 1464 @field_validator("axes", mode="after", check_fields=False) 1465 @classmethod 1466 def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]: 1467 batch_axes = [a for a in axes if a.type == "batch"] 1468 if len(batch_axes) > 1: 1469 raise ValueError( 1470 f"Only one batch axis (per tensor) allowed, but got {batch_axes}" 1471 ) 1472 1473 seen_ids: Set[AxisId] = set() 1474 duplicate_axes_ids: Set[AxisId] = set() 1475 for a in axes: 1476 (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id) 1477 1478 if duplicate_axes_ids: 1479 raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}") 1480 1481 return axes 1482 1483 test_tensor: FileDescr 1484 """An example tensor to use for testing. 1485 Using the model with the test input tensors is expected to yield the test output tensors. 1486 Each test tensor has to be an ndarray in the 1487 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format). 1488 The file extension must be '.npy'.""" 1489 1490 sample_tensor: Optional[FileDescr] = None 1491 """A sample tensor to illustrate a possible input/output for the model. 1492 The sample image primarily serves to inform a human user about an example use case 1493 and is typically stored as .hdf5, .png or .tiff. 1494 It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats) 1495 (numpy's `.npy` format is not supported). 1496 The image dimensionality has to match the number of axes specified in this tensor description. 
1497 """ 1498 1499 @model_validator(mode="after") 1500 def _validate_sample_tensor(self) -> Self: 1501 if self.sample_tensor is None or not get_validation_context().perform_io_checks: 1502 return self 1503 1504 local = download(self.sample_tensor.source, sha256=self.sample_tensor.sha256) 1505 tensor: NDArray[Any] = imread( 1506 local.path.read_bytes(), 1507 extension=PurePosixPath(local.original_file_name).suffix, 1508 ) 1509 n_dims = len(tensor.squeeze().shape) 1510 n_dims_min = n_dims_max = len(self.axes) 1511 1512 for a in self.axes: 1513 if isinstance(a, BatchAxis): 1514 n_dims_min -= 1 1515 elif isinstance(a.size, int): 1516 if a.size == 1: 1517 n_dims_min -= 1 1518 elif isinstance(a.size, (ParameterizedSize, DataDependentSize)): 1519 if a.size.min == 1: 1520 n_dims_min -= 1 1521 elif isinstance(a.size, SizeReference): 1522 if a.size.offset < 2: 1523 # size reference may result in singleton axis 1524 n_dims_min -= 1 1525 else: 1526 assert_never(a.size) 1527 1528 n_dims_min = max(0, n_dims_min) 1529 if n_dims < n_dims_min or n_dims > n_dims_max: 1530 raise ValueError( 1531 f"Expected sample tensor to have {n_dims_min} to" 1532 + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})." 1533 ) 1534 1535 return self 1536 1537 data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = ( 1538 IntervalOrRatioDataDescr() 1539 ) 1540 """Description of the tensor's data values, optionally per channel. 1541 If specified per channel, the data `type` needs to match across channels.""" 1542 1543 @property 1544 def dtype( 1545 self, 1546 ) -> Literal[ 1547 "float32", 1548 "float64", 1549 "uint8", 1550 "int8", 1551 "uint16", 1552 "int16", 1553 "uint32", 1554 "int32", 1555 "uint64", 1556 "int64", 1557 "bool", 1558 ]: 1559 """dtype as specified under `data.type` or `data[i].type`""" 1560 if isinstance(self.data, collections.abc.Sequence): 1561 return self.data[0].type 1562 else: 1563 return self.data.type 1564 1565 @field_validator("data", mode="after") 1566 @classmethod 1567 def _check_data_type_across_channels( 1568 cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] 1569 ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]: 1570 if not isinstance(value, list): 1571 return value 1572 1573 dtypes = {t.type for t in value} 1574 if len(dtypes) > 1: 1575 raise ValueError( 1576 "Tensor data descriptions per channel need to agree in their data" 1577 + f" `type`, but found {dtypes}." 1578 ) 1579 1580 return value 1581 1582 @model_validator(mode="after") 1583 def _check_data_matches_channelaxis(self) -> Self: 1584 if not isinstance(self.data, (list, tuple)): 1585 return self 1586 1587 for a in self.axes: 1588 if isinstance(a, ChannelAxis): 1589 size = a.size 1590 assert isinstance(size, int) 1591 break 1592 else: 1593 return self 1594 1595 if len(self.data) != size: 1596 raise ValueError( 1597 f"Got tensor data descriptions for {len(self.data)} channels, but" 1598 + f" '{a.id}' axis has size {size}." 1599 ) 1600 1601 return self 1602 1603 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]: 1604 if len(array.shape) != len(self.axes): 1605 raise ValueError( 1606 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})" 1607 + f" incompatible with {len(self.axes)} axes." 1608 ) 1609 return {a.id: array.shape[i] for i, a in enumerate(self.axes)} 1610 1611 1612class InputTensorDescr(TensorDescrBase[InputAxis]): 1613 id: TensorId = TensorId("input") 1614 """Input tensor id. 
1615 No duplicates are allowed across all inputs and outputs.""" 1616 1617 optional: bool = False 1618 """indicates that this tensor may be `None`""" 1619 1620 preprocessing: List[PreprocessingDescr] = Field(default_factory=list) 1621 """Description of how this input should be preprocessed. 1622 1623 notes: 1624 - If preprocessing does not start with an 'ensure_dtype' entry, it is added 1625 to ensure an input tensor's data type matches the input tensor's data description. 1626 - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 1627 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally 1628 changing the data type. 1629 """ 1630 1631 @model_validator(mode="after") 1632 def _validate_preprocessing_kwargs(self) -> Self: 1633 axes_ids = [a.id for a in self.axes] 1634 for p in self.preprocessing: 1635 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 1636 if kwargs_axes is None: 1637 continue 1638 1639 if not isinstance(kwargs_axes, collections.abc.Sequence): 1640 raise ValueError( 1641 f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}" 1642 ) 1643 1644 if any(a not in axes_ids for a in kwargs_axes): 1645 raise ValueError( 1646 "`preprocessing.i.kwargs.axes` needs to be subset of axes ids" 1647 ) 1648 1649 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 1650 dtype = self.data.type 1651 else: 1652 dtype = self.data[0].type 1653 1654 # ensure `preprocessing` begins with `EnsureDtypeDescr` 1655 if not self.preprocessing or not isinstance( 1656 self.preprocessing[0], EnsureDtypeDescr 1657 ): 1658 self.preprocessing.insert( 1659 0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1660 ) 1661 1662 # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 1663 if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)): 1664 self.preprocessing.append( 1665 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1666 ) 1667 1668 return self 1669 1670 1671def convert_axes( 1672 axes: str, 1673 *, 1674 shape: Union[ 1675 Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4 1676 ], 1677 tensor_type: Literal["input", "output"], 1678 halo: Optional[Sequence[int]], 1679 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 1680): 1681 ret: List[AnyAxis] = [] 1682 for i, a in enumerate(axes): 1683 axis_type = _AXIS_TYPE_MAP.get(a, a) 1684 if axis_type == "batch": 1685 ret.append(BatchAxis()) 1686 continue 1687 1688 scale = 1.0 1689 if isinstance(shape, _ParameterizedInputShape_v0_4): 1690 if shape.step[i] == 0: 1691 size = shape.min[i] 1692 else: 1693 size = ParameterizedSize(min=shape.min[i], step=shape.step[i]) 1694 elif isinstance(shape, _ImplicitOutputShape_v0_4): 1695 ref_t = str(shape.reference_tensor) 1696 if ref_t.count(".") == 1: 1697 t_id, orig_a_id = ref_t.split(".") 1698 else: 1699 t_id = ref_t 1700 orig_a_id = a 1701 1702 a_id = _AXIS_ID_MAP.get(orig_a_id, a) 1703 if not (orig_scale := shape.scale[i]): 1704 # old way to insert a new axis dimension 1705 size = int(2 * shape.offset[i]) 1706 else: 1707 scale = 1 / orig_scale 1708 if axis_type in ("channel", "index"): 1709 # these axes no longer have a scale 1710 offset_from_scale = orig_scale * size_refs.get( 1711 _TensorName_v0_4(t_id), {} 1712 ).get(orig_a_id, 0) 1713 else: 1714 offset_from_scale = 0 1715 size = SizeReference( 1716 tensor_id=TensorId(t_id), 1717 axis_id=AxisId(a_id), 1718 offset=int(offset_from_scale + 2 * shape.offset[i]), 1719 ) 1720 
else: 1721 size = shape[i] 1722 1723 if axis_type == "time": 1724 if tensor_type == "input": 1725 ret.append(TimeInputAxis(size=size, scale=scale)) 1726 else: 1727 assert not isinstance(size, ParameterizedSize) 1728 if halo is None: 1729 ret.append(TimeOutputAxis(size=size, scale=scale)) 1730 else: 1731 assert not isinstance(size, int) 1732 ret.append( 1733 TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i]) 1734 ) 1735 1736 elif axis_type == "index": 1737 if tensor_type == "input": 1738 ret.append(IndexInputAxis(size=size)) 1739 else: 1740 if isinstance(size, ParameterizedSize): 1741 size = DataDependentSize(min=size.min) 1742 1743 ret.append(IndexOutputAxis(size=size)) 1744 elif axis_type == "channel": 1745 assert not isinstance(size, ParameterizedSize) 1746 if isinstance(size, SizeReference): 1747 warnings.warn( 1748 "Conversion of channel size from an implicit output shape may be" 1749 + " wrong" 1750 ) 1751 ret.append( 1752 ChannelAxis( 1753 channel_names=[ 1754 Identifier(f"channel{i}") for i in range(size.offset) 1755 ] 1756 ) 1757 ) 1758 else: 1759 ret.append( 1760 ChannelAxis( 1761 channel_names=[Identifier(f"channel{i}") for i in range(size)] 1762 ) 1763 ) 1764 elif axis_type == "space": 1765 if tensor_type == "input": 1766 ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale)) 1767 else: 1768 assert not isinstance(size, ParameterizedSize) 1769 if halo is None or halo[i] == 0: 1770 ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale)) 1771 elif isinstance(size, int): 1772 raise NotImplementedError( 1773 f"output axis with halo and fixed size (here {size}) not allowed" 1774 ) 1775 else: 1776 ret.append( 1777 SpaceOutputAxisWithHalo( 1778 id=AxisId(a), size=size, scale=scale, halo=halo[i] 1779 ) 1780 ) 1781 1782 return ret 1783 1784 1785def _axes_letters_to_ids( 1786 axes: Optional[str], 1787) -> Optional[List[AxisId]]: 1788 if axes is None: 1789 return None 1790 1791 return [AxisId(a) for a in axes] 1792 1793 1794def _get_complement_v04_axis( 1795 tensor_axes: Sequence[str], axes: Optional[Sequence[str]] 1796) -> Optional[AxisId]: 1797 if axes is None: 1798 return None 1799 1800 non_complement_axes = set(axes) | {"b"} 1801 complement_axes = [a for a in tensor_axes if a not in non_complement_axes] 1802 if len(complement_axes) > 1: 1803 raise ValueError( 1804 f"Expected none or a single complement axis, but axes '{axes}' " 1805 + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'." 
1806 ) 1807 1808 return None if not complement_axes else AxisId(complement_axes[0]) 1809 1810 1811def _convert_proc( 1812 p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4], 1813 tensor_axes: Sequence[str], 1814) -> Union[PreprocessingDescr, PostprocessingDescr]: 1815 if isinstance(p, _BinarizeDescr_v0_4): 1816 return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold)) 1817 elif isinstance(p, _ClipDescr_v0_4): 1818 return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max)) 1819 elif isinstance(p, _SigmoidDescr_v0_4): 1820 return SigmoidDescr() 1821 elif isinstance(p, _ScaleLinearDescr_v0_4): 1822 axes = _axes_letters_to_ids(p.kwargs.axes) 1823 if p.kwargs.axes is None: 1824 axis = None 1825 else: 1826 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes) 1827 1828 if axis is None: 1829 assert not isinstance(p.kwargs.gain, list) 1830 assert not isinstance(p.kwargs.offset, list) 1831 kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset) 1832 else: 1833 kwargs = ScaleLinearAlongAxisKwargs( 1834 axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset 1835 ) 1836 return ScaleLinearDescr(kwargs=kwargs) 1837 elif isinstance(p, _ScaleMeanVarianceDescr_v0_4): 1838 return ScaleMeanVarianceDescr( 1839 kwargs=ScaleMeanVarianceKwargs( 1840 axes=_axes_letters_to_ids(p.kwargs.axes), 1841 reference_tensor=TensorId(str(p.kwargs.reference_tensor)), 1842 eps=p.kwargs.eps, 1843 ) 1844 ) 1845 elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4): 1846 if p.kwargs.mode == "fixed": 1847 mean = p.kwargs.mean 1848 std = p.kwargs.std 1849 assert mean is not None 1850 assert std is not None 1851 1852 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes) 1853 1854 if axis is None: 1855 return FixedZeroMeanUnitVarianceDescr( 1856 kwargs=FixedZeroMeanUnitVarianceKwargs( 1857 mean=mean, std=std # pyright: ignore[reportArgumentType] 1858 ) 1859 ) 1860 else: 1861 if not isinstance(mean, list): 1862 mean = [float(mean)] 1863 if not isinstance(std, list): 1864 std = [float(std)] 1865 1866 return FixedZeroMeanUnitVarianceDescr( 1867 kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 1868 axis=axis, mean=mean, std=std 1869 ) 1870 ) 1871 1872 else: 1873 axes = _axes_letters_to_ids(p.kwargs.axes) or [] 1874 if p.kwargs.mode == "per_dataset": 1875 axes = [AxisId("batch")] + axes 1876 if not axes: 1877 axes = None 1878 return ZeroMeanUnitVarianceDescr( 1879 kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps) 1880 ) 1881 1882 elif isinstance(p, _ScaleRangeDescr_v0_4): 1883 return ScaleRangeDescr( 1884 kwargs=ScaleRangeKwargs( 1885 axes=_axes_letters_to_ids(p.kwargs.axes), 1886 min_percentile=p.kwargs.min_percentile, 1887 max_percentile=p.kwargs.max_percentile, 1888 eps=p.kwargs.eps, 1889 ) 1890 ) 1891 else: 1892 assert_never(p) 1893 1894 1895class _InputTensorConv( 1896 Converter[ 1897 _InputTensorDescr_v0_4, 1898 InputTensorDescr, 1899 ImportantFileSource, 1900 Optional[ImportantFileSource], 1901 Mapping[_TensorName_v0_4, Mapping[str, int]], 1902 ] 1903): 1904 def _convert( 1905 self, 1906 src: _InputTensorDescr_v0_4, 1907 tgt: "type[InputTensorDescr] | type[dict[str, Any]]", 1908 test_tensor: ImportantFileSource, 1909 sample_tensor: Optional[ImportantFileSource], 1910 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 1911 ) -> "InputTensorDescr | dict[str, Any]": 1912 axes: List[InputAxis] = convert_axes( # pyright: ignore[reportAssignmentType] 1913 src.axes, 1914 shape=src.shape, 1915 tensor_type="input", 1916 halo=None, 1917 size_refs=size_refs, 
1918 ) 1919 prep: List[PreprocessingDescr] = [] 1920 for p in src.preprocessing: 1921 cp = _convert_proc(p, src.axes) 1922 assert not isinstance(cp, ScaleMeanVarianceDescr) 1923 prep.append(cp) 1924 1925 prep.append(EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32"))) 1926 1927 return tgt( 1928 axes=axes, 1929 id=TensorId(str(src.name)), 1930 test_tensor=FileDescr(source=test_tensor), 1931 sample_tensor=( 1932 None if sample_tensor is None else FileDescr(source=sample_tensor) 1933 ), 1934 data=dict(type=src.data_type), # pyright: ignore[reportArgumentType] 1935 preprocessing=prep, 1936 ) 1937 1938 1939_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr) 1940 1941 1942class OutputTensorDescr(TensorDescrBase[OutputAxis]): 1943 id: TensorId = TensorId("output") 1944 """Output tensor id. 1945 No duplicates are allowed across all inputs and outputs.""" 1946 1947 postprocessing: List[PostprocessingDescr] = Field(default_factory=list) 1948 """Description of how this output should be postprocessed. 1949 1950 note: `postprocessing` always ends with an 'ensure_dtype' operation. 1951 If not given this is added to cast to this tensor's `data.type`. 1952 """ 1953 1954 @model_validator(mode="after") 1955 def _validate_postprocessing_kwargs(self) -> Self: 1956 axes_ids = [a.id for a in self.axes] 1957 for p in self.postprocessing: 1958 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 1959 if kwargs_axes is None: 1960 continue 1961 1962 if not isinstance(kwargs_axes, collections.abc.Sequence): 1963 raise ValueError( 1964 f"expected `axes` sequence, but got {type(kwargs_axes)}" 1965 ) 1966 1967 if any(a not in axes_ids for a in kwargs_axes): 1968 raise ValueError("`kwargs.axes` needs to be subset of axes ids") 1969 1970 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 1971 dtype = self.data.type 1972 else: 1973 dtype = self.data[0].type 1974 1975 # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 1976 if not self.postprocessing or not isinstance( 1977 self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr) 1978 ): 1979 self.postprocessing.append( 1980 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1981 ) 1982 return self 1983 1984 1985class _OutputTensorConv( 1986 Converter[ 1987 _OutputTensorDescr_v0_4, 1988 OutputTensorDescr, 1989 ImportantFileSource, 1990 Optional[ImportantFileSource], 1991 Mapping[_TensorName_v0_4, Mapping[str, int]], 1992 ] 1993): 1994 def _convert( 1995 self, 1996 src: _OutputTensorDescr_v0_4, 1997 tgt: "type[OutputTensorDescr] | type[dict[str, Any]]", 1998 test_tensor: ImportantFileSource, 1999 sample_tensor: Optional[ImportantFileSource], 2000 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 2001 ) -> "OutputTensorDescr | dict[str, Any]": 2002 # TODO: split convert_axes into convert_output_axes and convert_input_axes 2003 axes: List[OutputAxis] = convert_axes( # pyright: ignore[reportAssignmentType] 2004 src.axes, 2005 shape=src.shape, 2006 tensor_type="output", 2007 halo=src.halo, 2008 size_refs=size_refs, 2009 ) 2010 data_descr: Dict[str, Any] = dict(type=src.data_type) 2011 if data_descr["type"] == "bool": 2012 data_descr["values"] = [False, True] 2013 2014 return tgt( 2015 axes=axes, 2016 id=TensorId(str(src.name)), 2017 test_tensor=FileDescr(source=test_tensor), 2018 sample_tensor=( 2019 None if sample_tensor is None else FileDescr(source=sample_tensor) 2020 ), 2021 data=data_descr, # pyright: ignore[reportArgumentType] 2022 postprocessing=[_convert_proc(p, 
src.axes) for p in src.postprocessing], 2023 ) 2024 2025 2026_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr) 2027 2028 2029TensorDescr = Union[InputTensorDescr, OutputTensorDescr] 2030 2031 2032def validate_tensors( 2033 tensors: Mapping[TensorId, Tuple[TensorDescr, NDArray[Any]]], 2034 tensor_origin: Literal[ 2035 "test_tensor" 2036 ], # for more precise error messages, e.g. 'test_tensor' 2037): 2038 all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, int]]] = {} 2039 2040 def e_msg(d: TensorDescr): 2041 return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]" 2042 2043 for descr, array in tensors.values(): 2044 try: 2045 axis_sizes = descr.get_axis_sizes_for_array(array) 2046 except ValueError as e: 2047 raise ValueError(f"{e_msg(descr)} {e}") 2048 else: 2049 all_tensor_axes[descr.id] = { 2050 a.id: (a, axis_sizes[a.id]) for a in descr.axes 2051 } 2052 2053 for descr, array in tensors.values(): 2054 if descr.dtype in ("float32", "float64"): 2055 invalid_test_tensor_dtype = array.dtype.name not in ( 2056 "float32", 2057 "float64", 2058 "uint8", 2059 "int8", 2060 "uint16", 2061 "int16", 2062 "uint32", 2063 "int32", 2064 "uint64", 2065 "int64", 2066 ) 2067 else: 2068 invalid_test_tensor_dtype = array.dtype.name != descr.dtype 2069 2070 if invalid_test_tensor_dtype: 2071 raise ValueError( 2072 f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not" 2073 + f" match described dtype '{descr.dtype}'" 2074 ) 2075 2076 if array.min() > -1e-4 and array.max() < 1e-4: 2077 raise ValueError( 2078 "Output values are too small for reliable testing." 2079 + f" Values <-1e5 or >=1e5 must be present in {tensor_origin}" 2080 ) 2081 2082 for a in descr.axes: 2083 actual_size = all_tensor_axes[descr.id][a.id][1] 2084 if a.size is None: 2085 continue 2086 2087 if isinstance(a.size, int): 2088 if actual_size != a.size: 2089 raise ValueError( 2090 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' " 2091 + f"has incompatible size {actual_size}, expected {a.size}" 2092 ) 2093 elif isinstance(a.size, ParameterizedSize): 2094 _ = a.size.validate_size(actual_size) 2095 elif isinstance(a.size, DataDependentSize): 2096 _ = a.size.validate_size(actual_size) 2097 elif isinstance(a.size, SizeReference): 2098 ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id) 2099 if ref_tensor_axes is None: 2100 raise ValueError( 2101 f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor" 2102 + f" reference '{a.size.tensor_id}'" 2103 ) 2104 2105 ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None)) 2106 if ref_axis is None or ref_size is None: 2107 raise ValueError( 2108 f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis" 2109 + f" reference '{a.size.tensor_id}.{a.size.axis_id}" 2110 ) 2111 2112 if a.unit != ref_axis.unit: 2113 raise ValueError( 2114 f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires" 2115 + " axis and reference axis to have the same `unit`, but" 2116 + f" {a.unit}!={ref_axis.unit}" 2117 ) 2118 2119 if actual_size != ( 2120 expected_size := ( 2121 ref_size * ref_axis.scale / a.scale + a.size.offset 2122 ) 2123 ): 2124 raise ValueError( 2125 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size" 2126 + f" {actual_size} invalid for referenced size {ref_size};" 2127 + f" expected {expected_size}" 2128 ) 2129 else: 2130 assert_never(a.size) 2131 2132 2133class EnvironmentFileDescr(FileDescr): 2134 source: Annotated[ 2135 ImportantFileSource, 2136 WithSuffix((".yaml", ".yml"), 
case_sensitive=True), 2137 Field( 2138 examples=["environment.yaml"], 2139 ), 2140 ] 2141 """∈📦 Conda environment file. 2142 Allows to specify custom dependencies, see conda docs: 2143 - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms) 2144 - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually) 2145 """ 2146 2147 2148class _ArchitectureCallableDescr(Node): 2149 callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])] 2150 """Identifier of the callable that returns a torch.nn.Module instance.""" 2151 2152 kwargs: Dict[str, YamlValue] = Field(default_factory=dict) 2153 """key word arguments for the `callable`""" 2154 2155 2156class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr): 2157 pass 2158 2159 2160class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr): 2161 import_from: str 2162 """Where to import the callable from, i.e. `from <import_from> import <callable>`""" 2163 2164 2165ArchitectureDescr = Annotated[ 2166 Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr], 2167 Field(union_mode="left_to_right"), 2168] 2169 2170 2171class _ArchFileConv( 2172 Converter[ 2173 _CallableFromFile_v0_4, 2174 ArchitectureFromFileDescr, 2175 Optional[Sha256], 2176 Dict[str, Any], 2177 ] 2178): 2179 def _convert( 2180 self, 2181 src: _CallableFromFile_v0_4, 2182 tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]", 2183 sha256: Optional[Sha256], 2184 kwargs: Dict[str, Any], 2185 ) -> "ArchitectureFromFileDescr | dict[str, Any]": 2186 if src.startswith("http") and src.count(":") == 2: 2187 http, source, callable_ = src.split(":") 2188 source = ":".join((http, source)) 2189 elif not src.startswith("http") and src.count(":") == 1: 2190 source, callable_ = src.split(":") 2191 else: 2192 source = str(src) 2193 callable_ = str(src) 2194 return tgt( 2195 callable=Identifier(callable_), 2196 source=cast(ImportantFileSource, source), 2197 sha256=sha256, 2198 kwargs=kwargs, 2199 ) 2200 2201 2202_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr) 2203 2204 2205class _ArchLibConv( 2206 Converter[ 2207 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any] 2208 ] 2209): 2210 def _convert( 2211 self, 2212 src: _CallableFromDepencency_v0_4, 2213 tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]", 2214 kwargs: Dict[str, Any], 2215 ) -> "ArchitectureFromLibraryDescr | dict[str, Any]": 2216 *mods, callable_ = src.split(".") 2217 import_from = ".".join(mods) 2218 return tgt( 2219 import_from=import_from, callable=Identifier(callable_), kwargs=kwargs 2220 ) 2221 2222 2223_arch_lib_conv = _ArchLibConv( 2224 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr 2225) 2226 2227 2228class WeightsEntryDescrBase(FileDescr): 2229 type: ClassVar[WeightsFormat] 2230 weights_format_name: ClassVar[str] # human readable 2231 2232 source: ImportantFileSource 2233 """∈📦 The weights file.""" 2234 2235 authors: Optional[List[Author]] = None 2236 """Authors 2237 Either the person(s) that have trained this model resulting in the original weights file. 2238 (If this is the initial weights entry, i.e. it does not have a `parent`) 2239 Or the person(s) who have converted the weights to this weights format. 2240 (If this is a child weight, i.e. 
it has a `parent` field) 2241 """ 2242 2243 parent: Annotated[ 2244 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"]) 2245 ] = None 2246 """The source weights these weights were converted from. 2247 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`, 2248 The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights. 2249 All weight entries except one (the initial set of weights resulting from training the model), 2250 need to have this field.""" 2251 2252 comment: str = "" 2253 """A comment about this weights entry, for example how these weights were created.""" 2254 2255 @model_validator(mode="after") 2256 def check_parent_is_not_self(self) -> Self: 2257 if self.type == self.parent: 2258 raise ValueError("Weights entry can't be it's own parent.") 2259 2260 return self 2261 2262 2263class KerasHdf5WeightsDescr(WeightsEntryDescrBase): 2264 type = "keras_hdf5" 2265 weights_format_name: ClassVar[str] = "Keras HDF5" 2266 tensorflow_version: Version 2267 """TensorFlow version used to create these weights.""" 2268 2269 2270class OnnxWeightsDescr(WeightsEntryDescrBase): 2271 type = "onnx" 2272 weights_format_name: ClassVar[str] = "ONNX" 2273 opset_version: Annotated[int, Ge(7)] 2274 """ONNX opset version""" 2275 2276 2277class PytorchStateDictWeightsDescr(WeightsEntryDescrBase): 2278 type = "pytorch_state_dict" 2279 weights_format_name: ClassVar[str] = "Pytorch State Dict" 2280 architecture: ArchitectureDescr 2281 pytorch_version: Version 2282 """Version of the PyTorch library used. 2283 If `architecture.depencencies` is specified it has to include pytorch and any version pinning has to be compatible. 2284 """ 2285 dependencies: Optional[EnvironmentFileDescr] = None 2286 """Custom depencies beyond pytorch. 2287 The conda environment file should include pytorch and any version pinning has to be compatible with 2288 `pytorch_version`. 2289 """ 2290 2291 2292class TensorflowJsWeightsDescr(WeightsEntryDescrBase): 2293 type = "tensorflow_js" 2294 weights_format_name: ClassVar[str] = "Tensorflow.js" 2295 tensorflow_version: Version 2296 """Version of the TensorFlow library used.""" 2297 2298 source: ImportantFileSource 2299 """∈📦 The multi-file weights. 2300 All required files/folders should be a zip archive.""" 2301 2302 2303class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase): 2304 type = "tensorflow_saved_model_bundle" 2305 weights_format_name: ClassVar[str] = "Tensorflow Saved Model" 2306 tensorflow_version: Version 2307 """Version of the TensorFlow library used.""" 2308 2309 dependencies: Optional[EnvironmentFileDescr] = None 2310 """Custom dependencies beyond tensorflow. 2311 Should include tensorflow and any version pinning has to be compatible with `tensorflow_version`.""" 2312 2313 source: ImportantFileSource 2314 """∈📦 The multi-file weights. 
2315 All required files/folders should be a zip archive.""" 2316 2317 2318class TorchscriptWeightsDescr(WeightsEntryDescrBase): 2319 type = "torchscript" 2320 weights_format_name: ClassVar[str] = "TorchScript" 2321 pytorch_version: Version 2322 """Version of the PyTorch library used.""" 2323 2324 2325class WeightsDescr(Node): 2326 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None 2327 onnx: Optional[OnnxWeightsDescr] = None 2328 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None 2329 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None 2330 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = ( 2331 None 2332 ) 2333 torchscript: Optional[TorchscriptWeightsDescr] = None 2334 2335 @model_validator(mode="after") 2336 def check_entries(self) -> Self: 2337 entries = {wtype for wtype, entry in self if entry is not None} 2338 2339 if not entries: 2340 raise ValueError("Missing weights entry") 2341 2342 entries_wo_parent = { 2343 wtype 2344 for wtype, entry in self 2345 if entry is not None and hasattr(entry, "parent") and entry.parent is None 2346 } 2347 if len(entries_wo_parent) != 1: 2348 issue_warning( 2349 "Exactly one weights entry may not specify the `parent` field (got" 2350 + " {value}). That entry is considered the original set of model weights." 2351 + " Other weight formats are created through conversion of the orignal or" 2352 + " already converted weights. They have to reference the weights format" 2353 + " they were converted from as their `parent`.", 2354 value=len(entries_wo_parent), 2355 field="weights", 2356 ) 2357 2358 for wtype, entry in self: 2359 if entry is None: 2360 continue 2361 2362 assert hasattr(entry, "type") 2363 assert hasattr(entry, "parent") 2364 assert wtype == entry.type 2365 if ( 2366 entry.parent is not None and entry.parent not in entries 2367 ): # self reference checked for `parent` field 2368 raise ValueError( 2369 f"`weights.{wtype}.parent={entry.parent} not in specified weight" 2370 + f" formats: {entries}" 2371 ) 2372 2373 return self 2374 2375 def __getitem__( 2376 self, 2377 key: Literal[ 2378 "keras_hdf5", 2379 "onnx", 2380 "pytorch_state_dict", 2381 "tensorflow_js", 2382 "tensorflow_saved_model_bundle", 2383 "torchscript", 2384 ], 2385 ): 2386 if key == "keras_hdf5": 2387 ret = self.keras_hdf5 2388 elif key == "onnx": 2389 ret = self.onnx 2390 elif key == "pytorch_state_dict": 2391 ret = self.pytorch_state_dict 2392 elif key == "tensorflow_js": 2393 ret = self.tensorflow_js 2394 elif key == "tensorflow_saved_model_bundle": 2395 ret = self.tensorflow_saved_model_bundle 2396 elif key == "torchscript": 2397 ret = self.torchscript 2398 else: 2399 raise KeyError(key) 2400 2401 if ret is None: 2402 raise KeyError(key) 2403 2404 return ret 2405 2406 @property 2407 def available_formats(self): 2408 return { 2409 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 2410 **({} if self.onnx is None else {"onnx": self.onnx}), 2411 **( 2412 {} 2413 if self.pytorch_state_dict is None 2414 else {"pytorch_state_dict": self.pytorch_state_dict} 2415 ), 2416 **( 2417 {} 2418 if self.tensorflow_js is None 2419 else {"tensorflow_js": self.tensorflow_js} 2420 ), 2421 **( 2422 {} 2423 if self.tensorflow_saved_model_bundle is None 2424 else { 2425 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 2426 } 2427 ), 2428 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 2429 } 2430 2431 @property 2432 def missing_formats(self): 2433 return { 2434 wf for wf in 
get_args(WeightsFormat) if wf not in self.available_formats 2435 } 2436 2437 2438class ModelId(ResourceId): 2439 pass 2440 2441 2442class LinkedModel(LinkedResourceBase): 2443 """Reference to a bioimage.io model.""" 2444 2445 id: ModelId 2446 """A valid model `id` from the bioimage.io collection.""" 2447 2448 2449class _DataDepSize(NamedTuple): 2450 min: int 2451 max: Optional[int] 2452 2453 2454class _AxisSizes(NamedTuple): 2455 """the lenghts of all axes of model inputs and outputs""" 2456 2457 inputs: Dict[Tuple[TensorId, AxisId], int] 2458 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] 2459 2460 2461class _TensorSizes(NamedTuple): 2462 """_AxisSizes as nested dicts""" 2463 2464 inputs: Dict[TensorId, Dict[AxisId, int]] 2465 outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]] 2466 2467 2468class ReproducibilityTolerance(Node, extra="allow"): 2469 """Describes what small numerical differences -- if any -- may be tolerated 2470 in the generated output when executing in different environments. 2471 2472 A tensor element *output* is considered mismatched to the **test_tensor** if 2473 abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**). 2474 (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).) 2475 2476 Motivation: 2477 For testing we can request the respective deep learning frameworks to be as 2478 reproducible as possible by setting seeds and chosing deterministic algorithms, 2479 but differences in operating systems, available hardware and installed drivers 2480 may still lead to numerical differences. 2481 """ 2482 2483 relative_tolerance: RelativeTolerance = 1e-3 2484 """Maximum relative tolerance of reproduced test tensor.""" 2485 2486 absolute_tolerance: AbsoluteTolerance = 1e-4 2487 """Maximum absolute tolerance of reproduced test tensor.""" 2488 2489 mismatched_elements_per_million: MismatchedElementsPerMillion = 0 2490 """Maximum number of mismatched elements/pixels per million to tolerate.""" 2491 2492 output_ids: Sequence[TensorId] = () 2493 """Limits the output tensor IDs these reproducibility details apply to.""" 2494 2495 weights_formats: Sequence[WeightsFormat] = () 2496 """Limits the weights formats these details apply to.""" 2497 2498 2499class BioimageioConfig(Node, extra="allow"): 2500 reproducibility_tolerance: Sequence[ReproducibilityTolerance] = () 2501 """Tolerances to allow when reproducing the model's test outputs 2502 from the model's test inputs. 2503 Only the first entry matching tensor id and weights format is considered. 2504 """ 2505 2506 2507class Config(Node, extra="allow"): 2508 bioimageio: BioimageioConfig = Field(default_factory=BioimageioConfig) 2509 2510 2511class ModelDescr(GenericModelDescrBase): 2512 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. 2513 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 2514 """ 2515 2516 implemented_format_version: ClassVar[Literal["0.5.4"]] = "0.5.4" 2517 if TYPE_CHECKING: 2518 format_version: Literal["0.5.4"] = "0.5.4" 2519 else: 2520 format_version: Literal["0.5.4"] 2521 """Version of the bioimage.io model description specification used. 2522 When creating a new model always use the latest micro/patch version described here. 
2523 The `format_version` is important for any consumer software to understand how to parse the fields. 2524 """ 2525 2526 implemented_type: ClassVar[Literal["model"]] = "model" 2527 if TYPE_CHECKING: 2528 type: Literal["model"] = "model" 2529 else: 2530 type: Literal["model"] 2531 """Specialized resource type 'model'""" 2532 2533 id: Optional[ModelId] = None 2534 """bioimage.io-wide unique resource identifier 2535 assigned by bioimage.io; version **un**specific.""" 2536 2537 authors: NotEmpty[List[Author]] 2538 """The authors are the creators of the model RDF and the primary points of contact.""" 2539 2540 documentation: Annotated[ 2541 DocumentationSource, 2542 Field( 2543 examples=[ 2544 "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md", 2545 "README.md", 2546 ], 2547 ), 2548 ] 2549 """∈📦 URL or relative path to a markdown file with additional documentation. 2550 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. 2551 The documentation should include a '#[#] Validation' (sub)section 2552 with details on how to quantitatively validate the model on unseen data.""" 2553 2554 @field_validator("documentation", mode="after") 2555 @classmethod 2556 def _validate_documentation(cls, value: DocumentationSource) -> DocumentationSource: 2557 if not get_validation_context().perform_io_checks: 2558 return value 2559 2560 doc_path = download(value).path 2561 doc_content = doc_path.read_text(encoding="utf-8") 2562 assert isinstance(doc_content, str) 2563 if not re.search("#.*[vV]alidation", doc_content): 2564 issue_warning( 2565 "No '# Validation' (sub)section found in {value}.", 2566 value=value, 2567 field="documentation", 2568 ) 2569 2570 return value 2571 2572 inputs: NotEmpty[Sequence[InputTensorDescr]] 2573 """Describes the input tensors expected by this model.""" 2574 2575 @field_validator("inputs", mode="after") 2576 @classmethod 2577 def _validate_input_axes( 2578 cls, inputs: Sequence[InputTensorDescr] 2579 ) -> Sequence[InputTensorDescr]: 2580 input_size_refs = cls._get_axes_with_independent_size(inputs) 2581 2582 for i, ipt in enumerate(inputs): 2583 valid_independent_refs: Dict[ 2584 Tuple[TensorId, AxisId], 2585 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2586 ] = { 2587 **{ 2588 (ipt.id, a.id): (ipt, a, a.size) 2589 for a in ipt.axes 2590 if not isinstance(a, BatchAxis) 2591 and isinstance(a.size, (int, ParameterizedSize)) 2592 }, 2593 **input_size_refs, 2594 } 2595 for a, ax in enumerate(ipt.axes): 2596 cls._validate_axis( 2597 "inputs", 2598 i=i, 2599 tensor_id=ipt.id, 2600 a=a, 2601 axis=ax, 2602 valid_independent_refs=valid_independent_refs, 2603 ) 2604 return inputs 2605 2606 @staticmethod 2607 def _validate_axis( 2608 field_name: str, 2609 i: int, 2610 tensor_id: TensorId, 2611 a: int, 2612 axis: AnyAxis, 2613 valid_independent_refs: Dict[ 2614 Tuple[TensorId, AxisId], 2615 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2616 ], 2617 ): 2618 if isinstance(axis, BatchAxis) or isinstance( 2619 axis.size, (int, ParameterizedSize, DataDependentSize) 2620 ): 2621 return 2622 elif not isinstance(axis.size, SizeReference): 2623 assert_never(axis.size) 2624 2625 # validate axis.size SizeReference 2626 ref = (axis.size.tensor_id, axis.size.axis_id) 2627 if ref not in valid_independent_refs: 2628 raise ValueError( 2629 "Invalid tensor axis reference at" 2630 + f" {field_name}[{i}].axes[{a}].size: {axis.size}." 
2631 ) 2632 if ref == (tensor_id, axis.id): 2633 raise ValueError( 2634 "Self-referencing not allowed for" 2635 + f" {field_name}[{i}].axes[{a}].size: {axis.size}" 2636 ) 2637 if axis.type == "channel": 2638 if valid_independent_refs[ref][1].type != "channel": 2639 raise ValueError( 2640 "A channel axis' size may only reference another fixed size" 2641 + " channel axis." 2642 ) 2643 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names: 2644 ref_size = valid_independent_refs[ref][2] 2645 assert isinstance(ref_size, int), ( 2646 "channel axis ref (another channel axis) has to specify fixed" 2647 + " size" 2648 ) 2649 generated_channel_names = [ 2650 Identifier(axis.channel_names.format(i=i)) 2651 for i in range(1, ref_size + 1) 2652 ] 2653 axis.channel_names = generated_channel_names 2654 2655 if (ax_unit := getattr(axis, "unit", None)) != ( 2656 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None) 2657 ): 2658 raise ValueError( 2659 "The units of an axis and its reference axis need to match, but" 2660 + f" '{ax_unit}' != '{ref_unit}'." 2661 ) 2662 ref_axis = valid_independent_refs[ref][1] 2663 if isinstance(ref_axis, BatchAxis): 2664 raise ValueError( 2665 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}" 2666 + " (a batch axis is not allowed as reference)." 2667 ) 2668 2669 if isinstance(axis, WithHalo): 2670 min_size = axis.size.get_size(axis, ref_axis, n=0) 2671 if (min_size - 2 * axis.halo) < 1: 2672 raise ValueError( 2673 f"axis {axis.id} with minimum size {min_size} is too small for halo" 2674 + f" {axis.halo}." 2675 ) 2676 2677 input_halo = axis.halo * axis.scale / ref_axis.scale 2678 if input_halo != int(input_halo) or input_halo % 2 == 1: 2679 raise ValueError( 2680 f"input_halo {input_halo} (output_halo {axis.halo} *" 2681 + f" output_scale {axis.scale} / input_scale {ref_axis.scale})" 2682 + f" {tensor_id}.{axis.id}." 
2683 ) 2684 2685 @model_validator(mode="after") 2686 def _validate_test_tensors(self) -> Self: 2687 if not get_validation_context().perform_io_checks: 2688 return self 2689 2690 test_output_arrays = [ 2691 load_array(descr.test_tensor.download().path) for descr in self.outputs 2692 ] 2693 test_input_arrays = [ 2694 load_array(descr.test_tensor.download().path) for descr in self.inputs 2695 ] 2696 2697 tensors = { 2698 descr.id: (descr, array) 2699 for descr, array in zip( 2700 chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays 2701 ) 2702 } 2703 validate_tensors(tensors, tensor_origin="test_tensor") 2704 2705 output_arrays = { 2706 descr.id: array for descr, array in zip(self.outputs, test_output_arrays) 2707 } 2708 for rep_tol in self.config.bioimageio.reproducibility_tolerance: 2709 if not rep_tol.absolute_tolerance: 2710 continue 2711 2712 if rep_tol.output_ids: 2713 out_arrays = { 2714 oid: a 2715 for oid, a in output_arrays.items() 2716 if oid in rep_tol.output_ids 2717 } 2718 else: 2719 out_arrays = output_arrays 2720 2721 for out_id, array in out_arrays.items(): 2722 if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01: 2723 raise ValueError( 2724 "config.bioimageio.reproducibility_tolerance.absolute_tolerance=" 2725 + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}" 2726 + f" (1% of the maximum value of the test tensor '{out_id}')" 2727 ) 2728 2729 return self 2730 2731 @model_validator(mode="after") 2732 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self: 2733 ipt_refs = {t.id for t in self.inputs} 2734 out_refs = {t.id for t in self.outputs} 2735 for ipt in self.inputs: 2736 for p in ipt.preprocessing: 2737 ref = p.kwargs.get("reference_tensor") 2738 if ref is None: 2739 continue 2740 if ref not in ipt_refs: 2741 raise ValueError( 2742 f"`reference_tensor` '{ref}' not found. Valid input tensor" 2743 + f" references are: {ipt_refs}." 2744 ) 2745 2746 for out in self.outputs: 2747 for p in out.postprocessing: 2748 ref = p.kwargs.get("reference_tensor") 2749 if ref is None: 2750 continue 2751 2752 if ref not in ipt_refs and ref not in out_refs: 2753 raise ValueError( 2754 f"`reference_tensor` '{ref}' not found. Valid tensor references" 2755 + f" are: {ipt_refs | out_refs}." 2756 ) 2757 2758 return self 2759 2760 # TODO: use validate funcs in validate_test_tensors 2761 # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]: 2762 2763 name: Annotated[ 2764 Annotated[ 2765 str, RestrictCharacters(string.ascii_letters + string.digits + "_+- ()") 2766 ], 2767 MinLen(5), 2768 MaxLen(128), 2769 warn(MaxLen(64), "Name longer than 64 characters.", INFO), 2770 ] 2771 """A human-readable name of this model. 2772 It should be no longer than 64 characters 2773 and may only contain letter, number, underscore, minus, parentheses and spaces. 2774 We recommend to chose a name that refers to the model's task and image modality. 
2775 """ 2776 2777 outputs: NotEmpty[Sequence[OutputTensorDescr]] 2778 """Describes the output tensors.""" 2779 2780 @field_validator("outputs", mode="after") 2781 @classmethod 2782 def _validate_tensor_ids( 2783 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo 2784 ) -> Sequence[OutputTensorDescr]: 2785 tensor_ids = [ 2786 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", []) 2787 ] 2788 duplicate_tensor_ids: List[str] = [] 2789 seen: Set[str] = set() 2790 for t in tensor_ids: 2791 if t in seen: 2792 duplicate_tensor_ids.append(t) 2793 2794 seen.add(t) 2795 2796 if duplicate_tensor_ids: 2797 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}") 2798 2799 return outputs 2800 2801 @staticmethod 2802 def _get_axes_with_parameterized_size( 2803 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2804 ): 2805 return { 2806 f"{t.id}.{a.id}": (t, a, a.size) 2807 for t in io 2808 for a in t.axes 2809 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize) 2810 } 2811 2812 @staticmethod 2813 def _get_axes_with_independent_size( 2814 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2815 ): 2816 return { 2817 (t.id, a.id): (t, a, a.size) 2818 for t in io 2819 for a in t.axes 2820 if not isinstance(a, BatchAxis) 2821 and isinstance(a.size, (int, ParameterizedSize)) 2822 } 2823 2824 @field_validator("outputs", mode="after") 2825 @classmethod 2826 def _validate_output_axes( 2827 cls, outputs: List[OutputTensorDescr], info: ValidationInfo 2828 ) -> List[OutputTensorDescr]: 2829 input_size_refs = cls._get_axes_with_independent_size( 2830 info.data.get("inputs", []) 2831 ) 2832 output_size_refs = cls._get_axes_with_independent_size(outputs) 2833 2834 for i, out in enumerate(outputs): 2835 valid_independent_refs: Dict[ 2836 Tuple[TensorId, AxisId], 2837 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2838 ] = { 2839 **{ 2840 (out.id, a.id): (out, a, a.size) 2841 for a in out.axes 2842 if not isinstance(a, BatchAxis) 2843 and isinstance(a.size, (int, ParameterizedSize)) 2844 }, 2845 **input_size_refs, 2846 **output_size_refs, 2847 } 2848 for a, ax in enumerate(out.axes): 2849 cls._validate_axis( 2850 "outputs", 2851 i, 2852 out.id, 2853 a, 2854 ax, 2855 valid_independent_refs=valid_independent_refs, 2856 ) 2857 2858 return outputs 2859 2860 packaged_by: List[Author] = Field(default_factory=list) 2861 """The persons that have packaged and uploaded this model. 2862 Only required if those persons differ from the `authors`.""" 2863 2864 parent: Optional[LinkedModel] = None 2865 """The model from which this model is derived, e.g. by fine-tuning the weights.""" 2866 2867 @model_validator(mode="after") 2868 def _validate_parent_is_not_self(self) -> Self: 2869 if self.parent is not None and self.parent.id == self.id: 2870 raise ValueError("A model description may not reference itself as parent.") 2871 2872 return self 2873 2874 run_mode: Annotated[ 2875 Optional[RunMode], 2876 warn(None, "Run mode '{value}' has limited support across consumer softwares."), 2877 ] = None 2878 """Custom run mode for this model: for more complex prediction procedures like test time 2879 data augmentation that currently cannot be expressed in the specification. 
2880 No standard run modes are defined yet.""" 2881 2882 timestamp: Datetime = Field(default_factory=Datetime.now) 2883 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format 2884 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat). 2885 (In Python a datetime object is valid, too).""" 2886 2887 training_data: Annotated[ 2888 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02], 2889 Field(union_mode="left_to_right"), 2890 ] = None 2891 """The dataset used to train this model""" 2892 2893 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)] 2894 """The weights for this model. 2895 Weights can be given for different formats, but should otherwise be equivalent. 2896 The available weight formats determine which consumers can use this model.""" 2897 2898 config: Config = Field(default_factory=Config) 2899 2900 @model_validator(mode="after") 2901 def _add_default_cover(self) -> Self: 2902 if not get_validation_context().perform_io_checks or self.covers: 2903 return self 2904 2905 try: 2906 generated_covers = generate_covers( 2907 [(t, load_array(t.test_tensor.download().path)) for t in self.inputs], 2908 [(t, load_array(t.test_tensor.download().path)) for t in self.outputs], 2909 ) 2910 except Exception as e: 2911 issue_warning( 2912 "Failed to generate cover image(s): {e}", 2913 value=self.covers, 2914 msg_context=dict(e=e), 2915 field="covers", 2916 ) 2917 else: 2918 self.covers.extend(generated_covers) 2919 2920 return self 2921 2922 def get_input_test_arrays(self) -> List[NDArray[Any]]: 2923 data = [load_array(ipt.test_tensor.download().path) for ipt in self.inputs] 2924 assert all(isinstance(d, np.ndarray) for d in data) 2925 return data 2926 2927 def get_output_test_arrays(self) -> List[NDArray[Any]]: 2928 data = [load_array(out.test_tensor.download().path) for out in self.outputs] 2929 assert all(isinstance(d, np.ndarray) for d in data) 2930 return data 2931 2932 @staticmethod 2933 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 2934 batch_size = 1 2935 tensor_with_batchsize: Optional[TensorId] = None 2936 for tid in tensor_sizes: 2937 for aid, s in tensor_sizes[tid].items(): 2938 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 2939 continue 2940 2941 if batch_size != 1: 2942 assert tensor_with_batchsize is not None 2943 raise ValueError( 2944 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 2945 ) 2946 2947 batch_size = s 2948 tensor_with_batchsize = tid 2949 2950 return batch_size 2951 2952 def get_output_tensor_sizes( 2953 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 2954 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: 2955 """Returns the tensor output sizes for given **input_sizes**. 2956 Only if **input_sizes** has a valid input shape, the tensor output size is exact. 
2957 Otherwise it might be larger than the actual (valid) output""" 2958 batch_size = self.get_batch_size(input_sizes) 2959 ns = self.get_ns(input_sizes) 2960 2961 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 2962 return tensor_sizes.outputs 2963 2964 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 2965 """get parameter `n` for each parameterized axis 2966 such that the valid input size is >= the given input size""" 2967 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 2968 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 2969 for tid in input_sizes: 2970 for aid, s in input_sizes[tid].items(): 2971 size_descr = axes[tid][aid].size 2972 if isinstance(size_descr, ParameterizedSize): 2973 ret[(tid, aid)] = size_descr.get_n(s) 2974 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 2975 pass 2976 else: 2977 assert_never(size_descr) 2978 2979 return ret 2980 2981 def get_tensor_sizes( 2982 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 2983 ) -> _TensorSizes: 2984 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 2985 return _TensorSizes( 2986 { 2987 t: { 2988 aa: axis_sizes.inputs[(tt, aa)] 2989 for tt, aa in axis_sizes.inputs 2990 if tt == t 2991 } 2992 for t in {tt for tt, _ in axis_sizes.inputs} 2993 }, 2994 { 2995 t: { 2996 aa: axis_sizes.outputs[(tt, aa)] 2997 for tt, aa in axis_sizes.outputs 2998 if tt == t 2999 } 3000 for t in {tt for tt, _ in axis_sizes.outputs} 3001 }, 3002 ) 3003 3004 def get_axis_sizes( 3005 self, 3006 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], 3007 batch_size: Optional[int] = None, 3008 *, 3009 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None, 3010 ) -> _AxisSizes: 3011 """Determine input and output block shape for scale factors **ns** 3012 of parameterized input sizes. 3013 3014 Args: 3015 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) 3016 that is parameterized as `size = min + n * step`. 3017 batch_size: The desired size of the batch dimension. 3018 If given **batch_size** overwrites any batch size present in 3019 **max_input_shape**. Default 1. 3020 max_input_shape: Limits the derived block shapes. 3021 Each axis for which the input size, parameterized by `n`, is larger 3022 than **max_input_shape** is set to the minimal value `n_min` for which 3023 this is still true. 3024 Use this for small input samples or large values of **ns**. 3025 Or simply whenever you know the full input shape. 3026 3027 Returns: 3028 Resolved axis sizes for model inputs and outputs. 
3029 """ 3030 max_input_shape = max_input_shape or {} 3031 if batch_size is None: 3032 for (_t_id, a_id), s in max_input_shape.items(): 3033 if a_id == BATCH_AXIS_ID: 3034 batch_size = s 3035 break 3036 else: 3037 batch_size = 1 3038 3039 all_axes = { 3040 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs) 3041 } 3042 3043 inputs: Dict[Tuple[TensorId, AxisId], int] = {} 3044 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {} 3045 3046 def get_axis_size(a: Union[InputAxis, OutputAxis]): 3047 if isinstance(a, BatchAxis): 3048 if (t_descr.id, a.id) in ns: 3049 logger.warning( 3050 "Ignoring unexpected size increment factor (n) for batch axis" 3051 + " of tensor '{}'.", 3052 t_descr.id, 3053 ) 3054 return batch_size 3055 elif isinstance(a.size, int): 3056 if (t_descr.id, a.id) in ns: 3057 logger.warning( 3058 "Ignoring unexpected size increment factor (n) for fixed size" 3059 + " axis '{}' of tensor '{}'.", 3060 a.id, 3061 t_descr.id, 3062 ) 3063 return a.size 3064 elif isinstance(a.size, ParameterizedSize): 3065 if (t_descr.id, a.id) not in ns: 3066 raise ValueError( 3067 "Size increment factor (n) missing for parametrized axis" 3068 + f" '{a.id}' of tensor '{t_descr.id}'." 3069 ) 3070 n = ns[(t_descr.id, a.id)] 3071 s_max = max_input_shape.get((t_descr.id, a.id)) 3072 if s_max is not None: 3073 n = min(n, a.size.get_n(s_max)) 3074 3075 return a.size.get_size(n) 3076 3077 elif isinstance(a.size, SizeReference): 3078 if (t_descr.id, a.id) in ns: 3079 logger.warning( 3080 "Ignoring unexpected size increment factor (n) for axis '{}'" 3081 + " of tensor '{}' with size reference.", 3082 a.id, 3083 t_descr.id, 3084 ) 3085 assert not isinstance(a, BatchAxis) 3086 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id] 3087 assert not isinstance(ref_axis, BatchAxis) 3088 ref_key = (a.size.tensor_id, a.size.axis_id) 3089 ref_size = inputs.get(ref_key, outputs.get(ref_key)) 3090 assert ref_size is not None, ref_key 3091 assert not isinstance(ref_size, _DataDepSize), ref_key 3092 return a.size.get_size( 3093 axis=a, 3094 ref_axis=ref_axis, 3095 ref_size=ref_size, 3096 ) 3097 elif isinstance(a.size, DataDependentSize): 3098 if (t_descr.id, a.id) in ns: 3099 logger.warning( 3100 "Ignoring unexpected increment factor (n) for data dependent" 3101 + " size axis '{}' of tensor '{}'.", 3102 a.id, 3103 t_descr.id, 3104 ) 3105 return _DataDepSize(a.size.min, a.size.max) 3106 else: 3107 assert_never(a.size) 3108 3109 # first resolve all , but the `SizeReference` input sizes 3110 for t_descr in self.inputs: 3111 for a in t_descr.axes: 3112 if not isinstance(a.size, SizeReference): 3113 s = get_axis_size(a) 3114 assert not isinstance(s, _DataDepSize) 3115 inputs[t_descr.id, a.id] = s 3116 3117 # resolve all other input axis sizes 3118 for t_descr in self.inputs: 3119 for a in t_descr.axes: 3120 if isinstance(a.size, SizeReference): 3121 s = get_axis_size(a) 3122 assert not isinstance(s, _DataDepSize) 3123 inputs[t_descr.id, a.id] = s 3124 3125 # resolve all output axis sizes 3126 for t_descr in self.outputs: 3127 for a in t_descr.axes: 3128 assert not isinstance(a.size, ParameterizedSize) 3129 s = get_axis_size(a) 3130 outputs[t_descr.id, a.id] = s 3131 3132 return _AxisSizes(inputs=inputs, outputs=outputs) 3133 3134 @model_validator(mode="before") 3135 @classmethod 3136 def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]: 3137 cls.convert_from_old_format_wo_validation(data) 3138 return data 3139 3140 @classmethod 3141 def convert_from_old_format_wo_validation(cls, 
data: Dict[str, Any]) -> None: 3142 """Convert metadata following an older format version to this classes' format 3143 without validating the result. 3144 """ 3145 if ( 3146 data.get("type") == "model" 3147 and isinstance(fv := data.get("format_version"), str) 3148 and fv.count(".") == 2 3149 ): 3150 fv_parts = fv.split(".") 3151 if any(not p.isdigit() for p in fv_parts): 3152 return 3153 3154 fv_tuple = tuple(map(int, fv_parts)) 3155 3156 assert cls.implemented_format_version_tuple[0:2] == (0, 5) 3157 if fv_tuple[:2] in ((0, 3), (0, 4)): 3158 m04 = _ModelDescr_v0_4.load(data) 3159 if isinstance(m04, InvalidDescr): 3160 try: 3161 updated = _model_conv.convert_as_dict( 3162 m04 # pyright: ignore[reportArgumentType] 3163 ) 3164 except Exception as e: 3165 logger.error( 3166 "Failed to convert from invalid model 0.4 description." 3167 + f"\nerror: {e}" 3168 + "\nProceeding with model 0.5 validation without conversion." 3169 ) 3170 updated = None 3171 else: 3172 updated = _model_conv.convert_as_dict(m04) 3173 3174 if updated is not None: 3175 data.clear() 3176 data.update(updated) 3177 3178 elif fv_tuple[:2] == (0, 5): 3179 # bump patch version 3180 data["format_version"] = cls.implemented_format_version 3181 3182 3183class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]): 3184 def _convert( 3185 self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]" 3186 ) -> "ModelDescr | dict[str, Any]": 3187 name = "".join( 3188 c if c in string.ascii_letters + string.digits + "_+- ()" else " " 3189 for c in src.name 3190 ) 3191 3192 def conv_authors(auths: Optional[Sequence[_Author_v0_4]]): 3193 conv = ( 3194 _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict 3195 ) 3196 return None if auths is None else [conv(a) for a in auths] 3197 3198 if TYPE_CHECKING: 3199 arch_file_conv = _arch_file_conv.convert 3200 arch_lib_conv = _arch_lib_conv.convert 3201 else: 3202 arch_file_conv = _arch_file_conv.convert_as_dict 3203 arch_lib_conv = _arch_lib_conv.convert_as_dict 3204 3205 input_size_refs = { 3206 ipt.name: { 3207 a: s 3208 for a, s in zip( 3209 ipt.axes, 3210 ( 3211 ipt.shape.min 3212 if isinstance(ipt.shape, _ParameterizedInputShape_v0_4) 3213 else ipt.shape 3214 ), 3215 ) 3216 } 3217 for ipt in src.inputs 3218 if ipt.shape 3219 } 3220 output_size_refs = { 3221 **{ 3222 out.name: {a: s for a, s in zip(out.axes, out.shape)} 3223 for out in src.outputs 3224 if not isinstance(out.shape, _ImplicitOutputShape_v0_4) 3225 }, 3226 **input_size_refs, 3227 } 3228 3229 return tgt( 3230 attachments=( 3231 [] 3232 if src.attachments is None 3233 else [FileDescr(source=f) for f in src.attachments.files] 3234 ), 3235 authors=[ 3236 _author_conv.convert_as_dict(a) for a in src.authors 3237 ], # pyright: ignore[reportArgumentType] 3238 cite=[ 3239 {"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite 3240 ], # pyright: ignore[reportArgumentType] 3241 config=src.config, # pyright: ignore[reportArgumentType] 3242 covers=src.covers, 3243 description=src.description, 3244 documentation=src.documentation, 3245 format_version="0.5.4", 3246 git_repo=src.git_repo, # pyright: ignore[reportArgumentType] 3247 icon=src.icon, 3248 id=None if src.id is None else ModelId(src.id), 3249 id_emoji=src.id_emoji, 3250 license=src.license, # type: ignore 3251 links=src.links, 3252 maintainers=[ 3253 _maintainer_conv.convert_as_dict(m) for m in src.maintainers 3254 ], # pyright: ignore[reportArgumentType] 3255 name=name, 3256 tags=src.tags, 3257 type=src.type, 3258 uploader=src.uploader, 3259 
version=src.version, 3260 inputs=[ # pyright: ignore[reportArgumentType] 3261 _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs) 3262 for ipt, tt, st, in zip( 3263 src.inputs, 3264 src.test_inputs, 3265 src.sample_inputs or [None] * len(src.test_inputs), 3266 ) 3267 ], 3268 outputs=[ # pyright: ignore[reportArgumentType] 3269 _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs) 3270 for out, tt, st, in zip( 3271 src.outputs, 3272 src.test_outputs, 3273 src.sample_outputs or [None] * len(src.test_outputs), 3274 ) 3275 ], 3276 parent=( 3277 None 3278 if src.parent is None 3279 else LinkedModel( 3280 id=ModelId( 3281 str(src.parent.id) 3282 + ( 3283 "" 3284 if src.parent.version_number is None 3285 else f"/{src.parent.version_number}" 3286 ) 3287 ) 3288 ) 3289 ), 3290 training_data=( 3291 None 3292 if src.training_data is None 3293 else ( 3294 LinkedDataset( 3295 id=DatasetId( 3296 str(src.training_data.id) 3297 + ( 3298 "" 3299 if src.training_data.version_number is None 3300 else f"/{src.training_data.version_number}" 3301 ) 3302 ) 3303 ) 3304 if isinstance(src.training_data, LinkedDataset02) 3305 else src.training_data 3306 ) 3307 ), 3308 packaged_by=[ 3309 _author_conv.convert_as_dict(a) for a in src.packaged_by 3310 ], # pyright: ignore[reportArgumentType] 3311 run_mode=src.run_mode, 3312 timestamp=src.timestamp, 3313 weights=(WeightsDescr if TYPE_CHECKING else dict)( 3314 keras_hdf5=(w := src.weights.keras_hdf5) 3315 and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)( 3316 authors=conv_authors(w.authors), 3317 source=w.source, 3318 tensorflow_version=w.tensorflow_version or Version("1.15"), 3319 parent=w.parent, 3320 ), 3321 onnx=(w := src.weights.onnx) 3322 and (OnnxWeightsDescr if TYPE_CHECKING else dict)( 3323 source=w.source, 3324 authors=conv_authors(w.authors), 3325 parent=w.parent, 3326 opset_version=w.opset_version or 15, 3327 ), 3328 pytorch_state_dict=(w := src.weights.pytorch_state_dict) 3329 and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)( 3330 source=w.source, 3331 authors=conv_authors(w.authors), 3332 parent=w.parent, 3333 architecture=( 3334 arch_file_conv( 3335 w.architecture, 3336 w.architecture_sha256, 3337 w.kwargs, 3338 ) 3339 if isinstance(w.architecture, _CallableFromFile_v0_4) 3340 else arch_lib_conv(w.architecture, w.kwargs) 3341 ), 3342 pytorch_version=w.pytorch_version or Version("1.10"), 3343 dependencies=( 3344 None 3345 if w.dependencies is None 3346 else (EnvironmentFileDescr if TYPE_CHECKING else dict)( 3347 source=cast( 3348 ImportantFileSource, 3349 str(deps := w.dependencies)[ 3350 ( 3351 len("conda:") 3352 if str(deps).startswith("conda:") 3353 else 0 3354 ) : 3355 ], 3356 ) 3357 ) 3358 ), 3359 ), 3360 tensorflow_js=(w := src.weights.tensorflow_js) 3361 and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)( 3362 source=w.source, 3363 authors=conv_authors(w.authors), 3364 parent=w.parent, 3365 tensorflow_version=w.tensorflow_version or Version("1.15"), 3366 ), 3367 tensorflow_saved_model_bundle=( 3368 w := src.weights.tensorflow_saved_model_bundle 3369 ) 3370 and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)( 3371 authors=conv_authors(w.authors), 3372 parent=w.parent, 3373 source=w.source, 3374 tensorflow_version=w.tensorflow_version or Version("1.15"), 3375 dependencies=( 3376 None 3377 if w.dependencies is None 3378 else (EnvironmentFileDescr if TYPE_CHECKING else dict)( 3379 source=cast( 3380 ImportantFileSource, 3381 ( 3382 str(w.dependencies)[len("conda:") :] 3383 if 
str(w.dependencies).startswith("conda:") 3384 else str(w.dependencies) 3385 ), 3386 ) 3387 ) 3388 ), 3389 ), 3390 torchscript=(w := src.weights.torchscript) 3391 and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)( 3392 source=w.source, 3393 authors=conv_authors(w.authors), 3394 parent=w.parent, 3395 pytorch_version=w.pytorch_version or Version("1.10"), 3396 ), 3397 ), 3398 ) 3399 3400 3401_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr) 3402 3403 3404# create better cover images for 3d data and non-image outputs 3405def generate_covers( 3406 inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]], 3407 outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]], 3408) -> List[Path]: 3409 def squeeze( 3410 data: NDArray[Any], axes: Sequence[AnyAxis] 3411 ) -> Tuple[NDArray[Any], List[AnyAxis]]: 3412 """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining""" 3413 if data.ndim != len(axes): 3414 raise ValueError( 3415 f"tensor shape {data.shape} does not match described axes" 3416 + f" {[a.id for a in axes]}" 3417 ) 3418 3419 axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1] 3420 return data.squeeze(), axes 3421 3422 def normalize( 3423 data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7 3424 ) -> NDArray[np.float32]: 3425 data = data.astype("float32") 3426 data -= data.min(axis=axis, keepdims=True) 3427 data /= data.max(axis=axis, keepdims=True) + eps 3428 return data 3429 3430 def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]): 3431 original_shape = data.shape 3432 data, axes = squeeze(data, axes) 3433 3434 # take slice fom any batch or index axis if needed 3435 # and convert the first channel axis and take a slice from any additional channel axes 3436 slices: Tuple[slice, ...] = () 3437 ndim = data.ndim 3438 ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2 3439 has_c_axis = False 3440 for i, a in enumerate(axes): 3441 s = data.shape[i] 3442 assert s > 1 3443 if ( 3444 isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis)) 3445 and ndim > ndim_need 3446 ): 3447 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3448 ndim -= 1 3449 elif isinstance(a, ChannelAxis): 3450 if has_c_axis: 3451 # second channel axis 3452 data = data[slices + (slice(0, 1),)] 3453 ndim -= 1 3454 else: 3455 has_c_axis = True 3456 if s == 2: 3457 # visualize two channels with cyan and magenta 3458 data = np.concatenate( 3459 [ 3460 data[slices + (slice(1, 2),)], 3461 data[slices + (slice(0, 1),)], 3462 ( 3463 data[slices + (slice(0, 1),)] 3464 + data[slices + (slice(1, 2),)] 3465 ) 3466 / 2, # TODO: take maximum instead? 
3467 ], 3468 axis=i, 3469 ) 3470 elif data.shape[i] == 3: 3471 pass # visualize 3 channels as RGB 3472 else: 3473 # visualize first 3 channels as RGB 3474 data = data[slices + (slice(3),)] 3475 3476 assert data.shape[i] == 3 3477 3478 slices += (slice(None),) 3479 3480 data, axes = squeeze(data, axes) 3481 assert len(axes) == ndim 3482 # take slice from z axis if needed 3483 slices = () 3484 if ndim > ndim_need: 3485 for i, a in enumerate(axes): 3486 s = data.shape[i] 3487 if a.id == AxisId("z"): 3488 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3489 data, axes = squeeze(data, axes) 3490 ndim -= 1 3491 break 3492 3493 slices += (slice(None),) 3494 3495 # take slice from any space or time axis 3496 slices = () 3497 3498 for i, a in enumerate(axes): 3499 if ndim <= ndim_need: 3500 break 3501 3502 s = data.shape[i] 3503 assert s > 1 3504 if isinstance( 3505 a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis) 3506 ): 3507 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3508 ndim -= 1 3509 3510 slices += (slice(None),) 3511 3512 del slices 3513 data, axes = squeeze(data, axes) 3514 assert len(axes) == ndim 3515 3516 if (has_c_axis and ndim != 3) or ndim != 2: 3517 raise ValueError( 3518 f"Failed to construct cover image from shape {original_shape}" 3519 ) 3520 3521 if not has_c_axis: 3522 assert ndim == 2 3523 data = np.repeat(data[:, :, None], 3, axis=2) 3524 axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB")))) 3525 ndim += 1 3526 3527 assert ndim == 3 3528 3529 # transpose axis order such that longest axis comes first... 3530 axis_order = list(np.argsort(list(data.shape))) 3531 axis_order.reverse() 3532 # ... and channel axis is last 3533 c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0] 3534 axis_order.append(axis_order.pop(c)) 3535 axes = [axes[ao] for ao in axis_order] 3536 data = data.transpose(axis_order) 3537 3538 # h, w = data.shape[:2] 3539 # if h / w in (1.0 or 2.0): 3540 # pass 3541 # elif h / w < 2: 3542 # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images 3543 3544 norm_along = ( 3545 tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None 3546 ) 3547 # normalize the data and map to 8 bit 3548 data = normalize(data, norm_along) 3549 data = (data * 255).astype("uint8") 3550 3551 return data 3552 3553 def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]): 3554 assert im0.dtype == im1.dtype == np.uint8 3555 assert im0.shape == im1.shape 3556 assert im0.ndim == 3 3557 N, M, C = im0.shape 3558 assert C == 3 3559 out = np.ones((N, M, C), dtype="uint8") 3560 for c in range(C): 3561 outc = np.tril(im0[..., c]) 3562 mask = outc == 0 3563 outc[mask] = np.triu(im1[..., c])[mask] 3564 out[..., c] = outc 3565 3566 return out 3567 3568 ipt_descr, ipt = inputs[0] 3569 out_descr, out = outputs[0] 3570 3571 ipt_img = to_2d_image(ipt, ipt_descr.axes) 3572 out_img = to_2d_image(out, out_descr.axes) 3573 3574 cover_folder = Path(mkdtemp()) 3575 if ipt_img.shape == out_img.shape: 3576 covers = [cover_folder / "cover.png"] 3577 imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img)) 3578 else: 3579 covers = [cover_folder / "input.png", cover_folder / "output.png"] 3580 imwrite(covers[0], ipt_img) 3581 imwrite(covers[1], out_img) 3582 3583 return covers
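Since `ModelDescr.get_batch_size` is a static method, the batch size resolution shown above can be exercised on its own. A minimal sketch with illustrative tensor/axis ids, assuming the module's `BATCH_AXIS_ID` is `AxisId("batch")`:

from bioimageio.spec.model.v0_5 import AxisId, ModelDescr, TensorId

sizes = {
    TensorId("raw"): {AxisId("batch"): 4, AxisId("x"): 512, AxisId("y"): 512},
    TensorId("mask"): {AxisId("batch"): 4, AxisId("x"): 512, AxisId("y"): 512},
}
# all batch axes agree, so the common batch size is returned
assert ModelDescr.get_batch_size(sizes) == 4
# mismatching batch sizes (e.g. 4 vs. 2 across tensors) raise a ValueError instead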
Space unit compatible with the OME-Zarr axes specification 0.5
Time unit compatible with the OME-Zarr axes specification 0.5
213class TensorId(LowerCaseIdentifier): 214 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 215 Annotated[LowerCaseIdentifierAnno, MaxLen(32)] 216 ]
229class AxisId(LowerCaseIdentifier): 230 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 231 Annotated[ 232 LowerCaseIdentifierAnno, 233 MaxLen(16), 234 AfterValidator(_normalize_axis_id), 235 ] 236 ]
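A minimal sketch of constructing these identifier types. The lower-case requirement and the MaxLen limits follow from the root models shown above; the exact normalization performed by `_normalize_axis_id` is not shown here.

from bioimageio.spec.model.v0_5 import AxisId, TensorId

tid = TensorId("raw")   # lower-case identifier, at most 32 characters
aid = AxisId("x")       # lower-case identifier, at most 16 characters,
                        # additionally passed through `_normalize_axis_id`
# values that are not valid lower-case identifiers are rejected with a
# pydantic ValidationError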
Annotates an integer to calculate a concrete axis size from a ParameterizedSize.
280class ParameterizedSize(Node): 281 """Describes a range of valid tensor axis sizes as `size = min + n*step`. 282 283 - **min** and **step** are given by the model description. 284 - All block size parameters n = 0,1,2,... yield a valid `size`. 285 - A greater block size parameter n results in a greater **size**. 286 This allows adjusting the axis size more generically. 287 """ 288 289 N: ClassVar[Type[int]] = ParameterizedSize_N 290 """Positive integer to parameterize this axis""" 291 292 min: Annotated[int, Gt(0)] 293 step: Annotated[int, Gt(0)] 294 295 def validate_size(self, size: int) -> int: 296 if size < self.min: 297 raise ValueError(f"size {size} < {self.min}") 298 if (size - self.min) % self.step != 0: 299 raise ValueError( 300 f"axis of size {size} is not parameterized by `min + n*step` =" 301 + f" `{self.min} + n*{self.step}`" 302 ) 303 304 return size 305 306 def get_size(self, n: ParameterizedSize_N) -> int: 307 return self.min + self.step * n 308 309 def get_n(self, s: int) -> ParameterizedSize_N: 310 """return the smallest n parameterizing a size greater than or equal to `s`""" 311 return ceil((s - self.min) / self.step)
Describes a range of valid tensor axis sizes as size = min + n*step.
- min and step are given by the model description.
- All block size parameters n = 0, 1, 2, ... yield a valid size.
- A greater block size parameter n results in a greater size. This allows adjusting the axis size more generically.
295 def validate_size(self, size: int) -> int: 296 if size < self.min: 297 raise ValueError(f"size {size} < {self.min}") 298 if (size - self.min) % self.step != 0: 299 raise ValueError( 300 f"axis of size {size} is not parameterized by `min + n*step` =" 301 + f" `{self.min} + n*{self.step}`" 302 ) 303 304 return size
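A minimal usage sketch of the `size = min + n*step` parameterization (illustrative values only):

from bioimageio.spec.model.v0_5 import ParameterizedSize

size = ParameterizedSize(min=64, step=16)
assert size.get_size(3) == 112            # 64 + 3*16
assert size.get_n(100) == 3               # smallest n with 64 + n*16 >= 100
assert size.validate_size(112) == 112
# size.validate_size(100) raises a ValueError: 100 is not of the form 64 + n*16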
314class DataDependentSize(Node): 315 min: Annotated[int, Gt(0)] = 1 316 max: Annotated[Optional[int], Gt(1)] = None 317 318 @model_validator(mode="after") 319 def _validate_max_gt_min(self): 320 if self.max is not None and self.min >= self.max: 321 raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}") 322 323 return self 324 325 def validate_size(self, size: int) -> int: 326 if size < self.min: 327 raise ValueError(f"size {size} < {self.min}") 328 329 if self.max is not None and size > self.max: 330 raise ValueError(f"size {size} > {self.max}") 331 332 return size
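A short sketch of the corresponding bounds check for data dependent sizes (illustrative values only):

from bioimageio.spec.model.v0_5 import DataDependentSize

dds = DataDependentSize(min=1, max=1024)
assert dds.validate_size(512) == 512
# dds.validate_size(2048) raises a ValueError (size > max);
# DataDependentSize(min=10, max=5) is rejected because `min` must be < `max`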
335class SizeReference(Node): 336 """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis. 337 338 `axis.size = reference.size * reference.scale / axis.scale + offset` 339 340 Note: 341 1. The axis and the referenced axis need to have the same unit (or no unit). 342 2. Batch axes may not be referenced. 343 3. Fractions are rounded down. 344 4. If the reference axis is `concatenable` the referencing axis is assumed to be 345 `concatenable` as well with the same block order. 346 347 Example: 348 An unisotropic input image of w*h=100*49 pixels depicts a phsical space of 200*196mm². 349 Let's assume that we want to express the image height h in relation to its width w 350 instead of only accepting input images of exactly 100*49 pixels 351 (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`). 352 353 >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2) 354 >>> h = SpaceInputAxis( 355 ... id=AxisId("h"), 356 ... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1), 357 ... unit="millimeter", 358 ... scale=4, 359 ... ) 360 >>> print(h.size.get_size(h, w)) 361 49 362 363 ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49 364 """ 365 366 tensor_id: TensorId 367 """tensor id of the reference axis""" 368 369 axis_id: AxisId 370 """axis id of the reference axis""" 371 372 offset: int = 0 373 374 def get_size( 375 self, 376 axis: Union[ 377 ChannelAxis, 378 IndexInputAxis, 379 IndexOutputAxis, 380 TimeInputAxis, 381 SpaceInputAxis, 382 TimeOutputAxis, 383 TimeOutputAxisWithHalo, 384 SpaceOutputAxis, 385 SpaceOutputAxisWithHalo, 386 ], 387 ref_axis: Union[ 388 ChannelAxis, 389 IndexInputAxis, 390 IndexOutputAxis, 391 TimeInputAxis, 392 SpaceInputAxis, 393 TimeOutputAxis, 394 TimeOutputAxisWithHalo, 395 SpaceOutputAxis, 396 SpaceOutputAxisWithHalo, 397 ], 398 n: ParameterizedSize_N = 0, 399 ref_size: Optional[int] = None, 400 ): 401 """Compute the concrete size for a given axis and its reference axis. 402 403 Args: 404 axis: The axis this `SizeReference` is the size of. 405 ref_axis: The reference axis to compute the size from. 406 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`) 407 and no fixed **ref_size** is given, 408 **n** is used to compute the size of the parameterized **ref_axis**. 409 ref_size: Overwrite the reference size instead of deriving it from 410 **ref_axis** 411 (**ref_axis.scale** is still used; any given **n** is ignored). 412 """ 413 assert ( 414 axis.size == self 415 ), "Given `axis.size` is not defined by this `SizeReference`" 416 417 assert ( 418 ref_axis.id == self.axis_id 419 ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}." 420 421 assert axis.unit == ref_axis.unit, ( 422 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`," 423 f" but {axis.unit}!={ref_axis.unit}" 424 ) 425 if ref_size is None: 426 if isinstance(ref_axis.size, (int, float)): 427 ref_size = ref_axis.size 428 elif isinstance(ref_axis.size, ParameterizedSize): 429 ref_size = ref_axis.size.get_size(n) 430 elif isinstance(ref_axis.size, DataDependentSize): 431 raise ValueError( 432 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`." 433 ) 434 elif isinstance(ref_axis.size, SizeReference): 435 raise ValueError( 436 "Reference axis referenced in `SizeReference` may not be sized by a" 437 + " `SizeReference` itself." 
438 ) 439 else: 440 assert_never(ref_axis.size) 441 442 return int(ref_size * ref_axis.scale / axis.scale + self.offset) 443 444 @staticmethod 445 def _get_unit( 446 axis: Union[ 447 ChannelAxis, 448 IndexInputAxis, 449 IndexOutputAxis, 450 TimeInputAxis, 451 SpaceInputAxis, 452 TimeOutputAxis, 453 TimeOutputAxisWithHalo, 454 SpaceOutputAxis, 455 SpaceOutputAxisWithHalo, 456 ], 457 ): 458 return axis.unit
A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.

`axis.size = reference.size * reference.scale / axis.scale + offset`

Note:
- The axis and the referenced axis need to have the same unit (or no unit).
- Batch axes may not be referenced.
- Fractions are rounded down.
- If the reference axis is `concatenable` the referencing axis is assumed to be `concatenable` as well with the same block order.

Example:
An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196 mm².
Let's assume that we want to express the image height h in relation to its width w
instead of only accepting input images of exactly 100*49 pixels
(for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).
>>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
>>> h = SpaceInputAxis(
... id=AxisId("h"),
... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
... unit="millimeter",
... scale=4,
... )
>>> print(h.size.get_size(h, w))
49
⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
374 def get_size( 375 self, 376 axis: Union[ 377 ChannelAxis, 378 IndexInputAxis, 379 IndexOutputAxis, 380 TimeInputAxis, 381 SpaceInputAxis, 382 TimeOutputAxis, 383 TimeOutputAxisWithHalo, 384 SpaceOutputAxis, 385 SpaceOutputAxisWithHalo, 386 ], 387 ref_axis: Union[ 388 ChannelAxis, 389 IndexInputAxis, 390 IndexOutputAxis, 391 TimeInputAxis, 392 SpaceInputAxis, 393 TimeOutputAxis, 394 TimeOutputAxisWithHalo, 395 SpaceOutputAxis, 396 SpaceOutputAxisWithHalo, 397 ], 398 n: ParameterizedSize_N = 0, 399 ref_size: Optional[int] = None, 400 ): 401 """Compute the concrete size for a given axis and its reference axis. 402 403 Args: 404 axis: The axis this `SizeReference` is the size of. 405 ref_axis: The reference axis to compute the size from. 406 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`) 407 and no fixed **ref_size** is given, 408 **n** is used to compute the size of the parameterized **ref_axis**. 409 ref_size: Overwrite the reference size instead of deriving it from 410 **ref_axis** 411 (**ref_axis.scale** is still used; any given **n** is ignored). 412 """ 413 assert ( 414 axis.size == self 415 ), "Given `axis.size` is not defined by this `SizeReference`" 416 417 assert ( 418 ref_axis.id == self.axis_id 419 ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}." 420 421 assert axis.unit == ref_axis.unit, ( 422 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`," 423 f" but {axis.unit}!={ref_axis.unit}" 424 ) 425 if ref_size is None: 426 if isinstance(ref_axis.size, (int, float)): 427 ref_size = ref_axis.size 428 elif isinstance(ref_axis.size, ParameterizedSize): 429 ref_size = ref_axis.size.get_size(n) 430 elif isinstance(ref_axis.size, DataDependentSize): 431 raise ValueError( 432 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`." 433 ) 434 elif isinstance(ref_axis.size, SizeReference): 435 raise ValueError( 436 "Reference axis referenced in `SizeReference` may not be sized by a" 437 + " `SizeReference` itself." 438 ) 439 else: 440 assert_never(ref_axis.size) 441 442 return int(ref_size * ref_axis.scale / axis.scale + self.offset)
Compute the concrete size for a given axis and its reference axis.

Arguments:
- axis: The axis this `SizeReference` is the size of.
- ref_axis: The reference axis to compute the size from.
- n: If the **ref_axis** is parameterized (of type `ParameterizedSize`) and no fixed **ref_size** is given, **n** is used to compute the size of the parameterized **ref_axis**.
- ref_size: Overwrite the reference size instead of deriving it from **ref_axis** (**ref_axis.scale** is still used; any given **n** is ignored).
468class WithHalo(Node): 469 halo: Annotated[int, Ge(1)] 470 """The halo should be cropped from the output tensor to avoid boundary effects. 471 It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`. 472 To document a halo that is already cropped by the model use `size.offset` instead.""" 473 474 size: Annotated[ 475 SizeReference, 476 Field( 477 examples=[ 478 10, 479 SizeReference( 480 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 481 ).model_dump(mode="json"), 482 ] 483 ), 484 ] 485 """reference to another axis with an optional offset (see `SizeReference`)"""
The halo should be cropped from the output tensor to avoid boundary effects.
It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
To document a halo that is already cropped by the model use `size.offset` instead.

reference to another axis with an optional offset (see `SizeReference`)
491class BatchAxis(AxisBase): 492 implemented_type: ClassVar[Literal["batch"]] = "batch" 493 if TYPE_CHECKING: 494 type: Literal["batch"] = "batch" 495 else: 496 type: Literal["batch"] 497 498 id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID 499 size: Optional[Literal[1]] = None 500 """The batch size may be fixed to 1, 501 otherwise (the default) it may be chosen arbitrarily depending on available memory""" 502 503 @property 504 def scale(self): 505 return 1.0 506 507 @property 508 def concatenable(self): 509 return True 510 511 @property 512 def unit(self): 513 return None
The batch size may be fixed to 1, otherwise (the default) it may be chosen arbitrarily depending on available memory
516class ChannelAxis(AxisBase): 517 implemented_type: ClassVar[Literal["channel"]] = "channel" 518 if TYPE_CHECKING: 519 type: Literal["channel"] = "channel" 520 else: 521 type: Literal["channel"] 522 523 id: NonBatchAxisId = AxisId("channel") 524 channel_names: NotEmpty[List[Identifier]] 525 526 @property 527 def size(self) -> int: 528 return len(self.channel_names) 529 530 @property 531 def concatenable(self): 532 return False 533 534 @property 535 def scale(self) -> float: 536 return 1.0 537 538 @property 539 def unit(self): 540 return None
543class IndexAxisBase(AxisBase): 544 implemented_type: ClassVar[Literal["index"]] = "index" 545 if TYPE_CHECKING: 546 type: Literal["index"] = "index" 547 else: 548 type: Literal["index"] 549 550 id: NonBatchAxisId = AxisId("index") 551 552 @property 553 def scale(self) -> float: 554 return 1.0 555 556 @property 557 def unit(self): 558 return None
581class IndexInputAxis(IndexAxisBase, _WithInputAxisSize): 582 concatenable: bool = False 583 """If a model has a `concatenable` input axis, it can be processed blockwise, 584 splitting a longer sample axis into blocks matching its input tensor description. 585 Output axes are concatenable if they have a `SizeReference` to a concatenable 586 input axis. 587 """
If a model has a `concatenable` input axis, it can be processed blockwise,
splitting a longer sample axis into blocks matching its input tensor description.
Output axes are concatenable if they have a `SizeReference` to a concatenable
input axis.
590class IndexOutputAxis(IndexAxisBase): 591 size: Annotated[ 592 Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize], 593 Field( 594 examples=[ 595 10, 596 SizeReference( 597 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 598 ).model_dump(mode="json"), 599 ] 600 ), 601 ] 602 """The size/length of this axis can be specified as 603 - fixed integer 604 - reference to another axis with an optional offset (`SizeReference`) 605 - data dependent size using `DataDependentSize` (size is only known after model inference) 606 """
The size/length of this axis can be specified as
- fixed integer
- reference to another axis with an optional offset (`SizeReference`)
- data dependent size using `DataDependentSize` (size is only known after model inference)
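A hedged sketch of the three variants (class names as defined in this module):

>>> fixed = IndexOutputAxis(size=10)
>>> referenced = IndexOutputAxis(
...     size=SizeReference(tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5)
... )
>>> data_dependent = IndexOutputAxis(size=DataDependentSize())  # size only known after inference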
609class TimeAxisBase(AxisBase): 610 implemented_type: ClassVar[Literal["time"]] = "time" 611 if TYPE_CHECKING: 612 type: Literal["time"] = "time" 613 else: 614 type: Literal["time"] 615 616 id: NonBatchAxisId = AxisId("time") 617 unit: Optional[TimeUnit] = None 618 scale: Annotated[float, Gt(0)] = 1.0
621class TimeInputAxis(TimeAxisBase, _WithInputAxisSize): 622 concatenable: bool = False 623 """If a model has a `concatenable` input axis, it can be processed blockwise, 624 splitting a longer sample axis into blocks matching its input tensor description. 625 Output axes are concatenable if they have a `SizeReference` to a concatenable 626 input axis. 627 """
If a model has a `concatenable` input axis, it can be processed blockwise,
splitting a longer sample axis into blocks matching its input tensor description.
Output axes are concatenable if they have a `SizeReference` to a concatenable
input axis.
630class SpaceAxisBase(AxisBase): 631 implemented_type: ClassVar[Literal["space"]] = "space" 632 if TYPE_CHECKING: 633 type: Literal["space"] = "space" 634 else: 635 type: Literal["space"] 636 637 id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x") 638 unit: Optional[SpaceUnit] = None 639 scale: Annotated[float, Gt(0)] = 1.0
An axis id unique across all axes of one tensor.
642class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize): 643 concatenable: bool = False 644 """If a model has a `concatenable` input axis, it can be processed blockwise, 645 splitting a longer sample axis into blocks matching its input tensor description. 646 Output axes are concatenable if they have a `SizeReference` to a concatenable 647 input axis. 648 """
If a model has a `concatenable` input axis, it can be processed blockwise,
splitting a longer sample axis into blocks matching its input tensor description.
Output axes are concatenable if they have a `SizeReference` to a concatenable
input axis.
770class NominalOrOrdinalDataDescr(Node): 771 values: TVs 772 """A fixed set of nominal or an ascending sequence of ordinal values. 773 In this case `data.type` is required to be an unsigend integer type, e.g. 'uint8'. 774 String `values` are interpreted as labels for tensor values 0, ..., N. 775 Note: as YAML 1.2 does not natively support a "set" datatype, 776 nominal values should be given as a sequence (aka list/array) as well. 777 """ 778 779 type: Annotated[ 780 NominalOrOrdinalDType, 781 Field( 782 examples=[ 783 "float32", 784 "uint8", 785 "uint16", 786 "int64", 787 "bool", 788 ], 789 ), 790 ] = "uint8" 791 792 @model_validator(mode="after") 793 def _validate_values_match_type( 794 self, 795 ) -> Self: 796 incompatible: List[Any] = [] 797 for v in self.values: 798 if self.type == "bool": 799 if not isinstance(v, bool): 800 incompatible.append(v) 801 elif self.type in DTYPE_LIMITS: 802 if ( 803 isinstance(v, (int, float)) 804 and ( 805 v < DTYPE_LIMITS[self.type].min 806 or v > DTYPE_LIMITS[self.type].max 807 ) 808 or (isinstance(v, str) and "uint" not in self.type) 809 or (isinstance(v, float) and "int" in self.type) 810 ): 811 incompatible.append(v) 812 else: 813 incompatible.append(v) 814 815 if len(incompatible) == 5: 816 incompatible.append("...") 817 break 818 819 if incompatible: 820 raise ValueError( 821 f"data type '{self.type}' incompatible with values {incompatible}" 822 ) 823 824 return self 825 826 unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None 827 828 @property 829 def range(self): 830 if isinstance(self.values[0], str): 831 return 0, len(self.values) - 1 832 else: 833 return min(self.values), max(self.values)
A fixed set of nominal or an ascending sequence of ordinal values.
In this case `data.type` is required to be an unsigned integer type, e.g. 'uint8'.
String `values` are interpreted as labels for tensor values 0, ..., N.
Note: as YAML 1.2 does not natively support a "set" datatype,
nominal values should be given as a sequence (aka list/array) as well.
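For illustration (a minimal sketch; the label names are made up), string `values` acting as labels for an unsigned integer tensor could look like this:

>>> label_data = NominalOrOrdinalDataDescr(
...     values=["background", "cell", "membrane"],  # labels for tensor values 0, 1, 2
...     type="uint8",
... )
>>> label_data.range
(0, 2)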
850class IntervalOrRatioDataDescr(Node): 851 type: Annotated[ # todo: rename to dtype 852 IntervalOrRatioDType, 853 Field( 854 examples=["float32", "float64", "uint8", "uint16"], 855 ), 856 ] = "float32" 857 range: Tuple[Optional[float], Optional[float]] = ( 858 None, 859 None, 860 ) 861 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor. 862 `None` corresponds to min/max of what can be expressed by **type**.""" 863 unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit" 864 scale: float = 1.0 865 """Scale for data on an interval (or ratio) scale.""" 866 offset: Optional[float] = None 867 """Offset for data on a ratio scale."""
processing base class
877class BinarizeKwargs(ProcessingKwargs): 878 """key word arguments for `BinarizeDescr`""" 879 880 threshold: float 881 """The fixed threshold"""
key word arguments for BinarizeDescr
884class BinarizeAlongAxisKwargs(ProcessingKwargs): 885 """key word arguments for `BinarizeDescr`""" 886 887 threshold: NotEmpty[List[float]] 888 """The fixed threshold values along `axis`""" 889 890 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 891 """The `threshold` axis"""
key word arguments for BinarizeDescr
894class BinarizeDescr(ProcessingDescrBase): 895 """Binarize the tensor with a fixed threshold. 896 897 Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold` 898 will be set to one, values below the threshold to zero. 899 900 Examples: 901 - in YAML 902 ```yaml 903 postprocessing: 904 - id: binarize 905 kwargs: 906 axis: 'channel' 907 threshold: [0.25, 0.5, 0.75] 908 ``` 909 - in Python: 910 >>> postprocessing = [BinarizeDescr( 911 ... kwargs=BinarizeAlongAxisKwargs( 912 ... axis=AxisId('channel'), 913 ... threshold=[0.25, 0.5, 0.75], 914 ... ) 915 ... )] 916 """ 917 918 implemented_id: ClassVar[Literal["binarize"]] = "binarize" 919 if TYPE_CHECKING: 920 id: Literal["binarize"] = "binarize" 921 else: 922 id: Literal["binarize"] 923 kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]
Binarize the tensor with a fixed threshold.
Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold`
will be set to one, values below the threshold to zero.
Examples:
- in YAML
postprocessing:
- id: binarize
kwargs:
axis: 'channel'
threshold: [0.25, 0.5, 0.75]
- in Python:
>>> postprocessing = [BinarizeDescr(
...     kwargs=BinarizeAlongAxisKwargs(
...         axis=AxisId('channel'),
...         threshold=[0.25, 0.5, 0.75],
...     )
... )]
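A scalar threshold uses `BinarizeKwargs` instead (a minimal sketch mirroring the example above):

>>> postprocessing = [BinarizeDescr(kwargs=BinarizeKwargs(threshold=0.5))]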
926class ClipDescr(ProcessingDescrBase): 927 """Set tensor values below min to min and above max to max. 928 929 See `ScaleRangeDescr` for examples. 930 """ 931 932 implemented_id: ClassVar[Literal["clip"]] = "clip" 933 if TYPE_CHECKING: 934 id: Literal["clip"] = "clip" 935 else: 936 id: Literal["clip"] 937 938 kwargs: ClipKwargs
Set tensor values below min to min and above max to max.
See `ScaleRangeDescr` for examples.
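A minimal sketch based on the `ClipKwargs` usage shown elsewhere in this module:

>>> postprocessing = [ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0))]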
941class EnsureDtypeKwargs(ProcessingKwargs): 942 """key word arguments for `EnsureDtypeDescr`""" 943 944 dtype: Literal[ 945 "float32", 946 "float64", 947 "uint8", 948 "int8", 949 "uint16", 950 "int16", 951 "uint32", 952 "int32", 953 "uint64", 954 "int64", 955 "bool", 956 ]
key word arguments for EnsureDtypeDescr
959class EnsureDtypeDescr(ProcessingDescrBase): 960 """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching). 961 962 This can for example be used to ensure the inner neural network model gets a 963 different input tensor data type than the fully described bioimage.io model does. 964 965 Examples: 966 The described bioimage.io model (incl. preprocessing) accepts any 967 float32-compatible tensor, normalizes it with percentiles and clipping and then 968 casts it to uint8, which is what the neural network in this example expects. 969 - in YAML 970 ```yaml 971 inputs: 972 - data: 973 type: float32 # described bioimage.io model is compatible with any float32 input tensor 974 preprocessing: 975 - id: scale_range 976 kwargs: 977 axes: ['y', 'x'] 978 max_percentile: 99.8 979 min_percentile: 5.0 980 - id: clip 981 kwargs: 982 min: 0.0 983 max: 1.0 984 - id: ensure_dtype 985 kwargs: 986 dtype: uint8 987 ``` 988 - in Python: 989 >>> preprocessing = [ 990 ... ScaleRangeDescr( 991 ... kwargs=ScaleRangeKwargs( 992 ... axes= (AxisId('y'), AxisId('x')), 993 ... max_percentile= 99.8, 994 ... min_percentile= 5.0, 995 ... ) 996 ... ), 997 ... ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)), 998 ... EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")), 999 ... ] 1000 """ 1001 1002 implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype" 1003 if TYPE_CHECKING: 1004 id: Literal["ensure_dtype"] = "ensure_dtype" 1005 else: 1006 id: Literal["ensure_dtype"] 1007 1008 kwargs: EnsureDtypeKwargs
Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).
This can for example be used to ensure the inner neural network model gets a different input tensor data type than the fully described bioimage.io model does.
Examples:
The described bioimage.io model (incl. preprocessing) accepts any float32-compatible tensor, normalizes it with percentiles and clipping and then casts it to uint8, which is what the neural network in this example expects.
- in YAML
inputs:
  - data:
      type: float32  # described bioimage.io model is compatible with any float32 input tensor
    preprocessing:
      - id: scale_range
        kwargs:
          axes: ['y', 'x']
          max_percentile: 99.8
          min_percentile: 5.0
      - id: clip
        kwargs:
          min: 0.0
          max: 1.0
      - id: ensure_dtype
        kwargs:
          dtype: uint8
- in Python:
>>> preprocessing = [
...     ScaleRangeDescr(
...         kwargs=ScaleRangeKwargs(
...             axes= (AxisId('y'), AxisId('x')),
...             max_percentile= 99.8,
...             min_percentile= 5.0,
...         )
...     ),
...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
... ]
1011class ScaleLinearKwargs(ProcessingKwargs): 1012 """Key word arguments for `ScaleLinearDescr`""" 1013 1014 gain: float = 1.0 1015 """multiplicative factor""" 1016 1017 offset: float = 0.0 1018 """additive term""" 1019 1020 @model_validator(mode="after") 1021 def _validate(self) -> Self: 1022 if self.gain == 1.0 and self.offset == 0.0: 1023 raise ValueError( 1024 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 1025 + " != 0.0." 1026 ) 1027 1028 return self
Key word arguments for ScaleLinearDescr
1031class ScaleLinearAlongAxisKwargs(ProcessingKwargs): 1032 """Key word arguments for `ScaleLinearDescr`""" 1033 1034 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 1035 """The axis of gain and offset values.""" 1036 1037 gain: Union[float, NotEmpty[List[float]]] = 1.0 1038 """multiplicative factor""" 1039 1040 offset: Union[float, NotEmpty[List[float]]] = 0.0 1041 """additive term""" 1042 1043 @model_validator(mode="after") 1044 def _validate(self) -> Self: 1045 1046 if isinstance(self.gain, list): 1047 if isinstance(self.offset, list): 1048 if len(self.gain) != len(self.offset): 1049 raise ValueError( 1050 f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match." 1051 ) 1052 else: 1053 self.offset = [float(self.offset)] * len(self.gain) 1054 elif isinstance(self.offset, list): 1055 self.gain = [float(self.gain)] * len(self.offset) 1056 else: 1057 raise ValueError( 1058 "Do not specify an `axis` for scalar gain and offset values." 1059 ) 1060 1061 if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset): 1062 raise ValueError( 1063 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 1064 + " != 0.0." 1065 ) 1066 1067 return self
Key word arguments for ScaleLinearDescr
The axis of gain and offset values.
1070class ScaleLinearDescr(ProcessingDescrBase): 1071 """Fixed linear scaling. 1072 1073 Examples: 1074 1. Scale with scalar gain and offset 1075 - in YAML 1076 ```yaml 1077 preprocessing: 1078 - id: scale_linear 1079 kwargs: 1080 gain: 2.0 1081 offset: 3.0 1082 ``` 1083 - in Python: 1084 >>> preprocessing = [ 1085 ... ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0)) 1086 ... ] 1087 1088 2. Independent scaling along an axis 1089 - in YAML 1090 ```yaml 1091 preprocessing: 1092 - id: scale_linear 1093 kwargs: 1094 axis: 'channel' 1095 gain: [1.0, 2.0, 3.0] 1096 ``` 1097 - in Python: 1098 >>> preprocessing = [ 1099 ... ScaleLinearDescr( 1100 ... kwargs=ScaleLinearAlongAxisKwargs( 1101 ... axis=AxisId("channel"), 1102 ... gain=[1.0, 2.0, 3.0], 1103 ... ) 1104 ... ) 1105 ... ] 1106 1107 """ 1108 1109 implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear" 1110 if TYPE_CHECKING: 1111 id: Literal["scale_linear"] = "scale_linear" 1112 else: 1113 id: Literal["scale_linear"] 1114 kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]
Fixed linear scaling.
Examples:
- Scale with scalar gain and offset
- in YAML
preprocessing:
  - id: scale_linear
    kwargs:
      gain: 2.0
      offset: 3.0
- in Python:
>>> preprocessing = [
...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
... ]

- Independent scaling along an axis
- in YAML
preprocessing:
  - id: scale_linear
    kwargs:
      axis: 'channel'
      gain: [1.0, 2.0, 3.0]
- in Python:
>>> preprocessing = [
...     ScaleLinearDescr(
...         kwargs=ScaleLinearAlongAxisKwargs(
...             axis=AxisId("channel"),
...             gain=[1.0, 2.0, 3.0],
...         )
...     )
... ]
1117class SigmoidDescr(ProcessingDescrBase): 1118 """The logistic sigmoid funciton, a.k.a. expit function. 1119 1120 Examples: 1121 - in YAML 1122 ```yaml 1123 postprocessing: 1124 - id: sigmoid 1125 ``` 1126 - in Python: 1127 >>> postprocessing = [SigmoidDescr()] 1128 """ 1129 1130 implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid" 1131 if TYPE_CHECKING: 1132 id: Literal["sigmoid"] = "sigmoid" 1133 else: 1134 id: Literal["sigmoid"] 1135 1136 @property 1137 def kwargs(self) -> ProcessingKwargs: 1138 """empty kwargs""" 1139 return ProcessingKwargs()
The logistic sigmoid function, a.k.a. expit function.
Examples:
- in YAML
postprocessing:
- id: sigmoid
- in Python:
>>> postprocessing = [SigmoidDescr()]
1142class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs): 1143 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 1144 1145 mean: float 1146 """The mean value to normalize with.""" 1147 1148 std: Annotated[float, Ge(1e-6)] 1149 """The standard deviation value to normalize with."""
key word arguments for FixedZeroMeanUnitVarianceDescr
1152class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs): 1153 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 1154 1155 mean: NotEmpty[List[float]] 1156 """The mean value(s) to normalize with.""" 1157 1158 std: NotEmpty[List[Annotated[float, Ge(1e-6)]]] 1159 """The standard deviation value(s) to normalize with. 1160 Size must match `mean` values.""" 1161 1162 axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])] 1163 """The axis of the mean/std values to normalize each entry along that dimension 1164 separately.""" 1165 1166 @model_validator(mode="after") 1167 def _mean_and_std_match(self) -> Self: 1168 if len(self.mean) != len(self.std): 1169 raise ValueError( 1170 f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})" 1171 + " must match." 1172 ) 1173 1174 return self
key word arguments for FixedZeroMeanUnitVarianceDescr
The standard deviation value(s) to normalize with.
Size must match `mean` values.
The axis of the mean/std values to normalize each entry along that dimension separately.
1177class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase): 1178 """Subtract a given mean and divide by the standard deviation. 1179 1180 Normalize with fixed, precomputed values for 1181 `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std` 1182 Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given 1183 axes. 1184 1185 Examples: 1186 1. scalar value for whole tensor 1187 - in YAML 1188 ```yaml 1189 preprocessing: 1190 - id: fixed_zero_mean_unit_variance 1191 kwargs: 1192 mean: 103.5 1193 std: 13.7 1194 ``` 1195 - in Python 1196 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 1197 ... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7) 1198 ... )] 1199 1200 2. independently along an axis 1201 - in YAML 1202 ```yaml 1203 preprocessing: 1204 - id: fixed_zero_mean_unit_variance 1205 kwargs: 1206 axis: channel 1207 mean: [101.5, 102.5, 103.5] 1208 std: [11.7, 12.7, 13.7] 1209 ``` 1210 - in Python 1211 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 1212 ... kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 1213 ... axis=AxisId("channel"), 1214 ... mean=[101.5, 102.5, 103.5], 1215 ... std=[11.7, 12.7, 13.7], 1216 ... ) 1217 ... )] 1218 """ 1219 1220 implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = ( 1221 "fixed_zero_mean_unit_variance" 1222 ) 1223 if TYPE_CHECKING: 1224 id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance" 1225 else: 1226 id: Literal["fixed_zero_mean_unit_variance"] 1227 1228 kwargs: Union[ 1229 FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs 1230 ]
Subtract a given mean and divide by the standard deviation.
Normalize with fixed, precomputed values for
`FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`.
Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
axes.
Examples:
- scalar value for whole tensor
- in YAML
preprocessing:
- id: fixed_zero_mean_unit_variance
kwargs:
mean: 103.5
std: 13.7
- in Python
>>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
... )]
- independently along an axis
- in YAML
preprocessing:
  - id: fixed_zero_mean_unit_variance
    kwargs:
      axis: channel
      mean: [101.5, 102.5, 103.5]
      std: [11.7, 12.7, 13.7]
- in Python
>>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
...     kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
...         axis=AxisId("channel"),
...         mean=[101.5, 102.5, 103.5],
...         std=[11.7, 12.7, 13.7],
...     )
... )]
1233class ZeroMeanUnitVarianceKwargs(ProcessingKwargs): 1234 """key word arguments for `ZeroMeanUnitVarianceDescr`""" 1235 1236 axes: Annotated[ 1237 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1238 ] = None 1239 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 1240 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1241 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1242 To normalize each sample independently leave out the 'batch' axis. 1243 Default: Scale all axes jointly.""" 1244 1245 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1246 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""
key word arguments for ZeroMeanUnitVarianceDescr
The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
To normalize each sample independently leave out the 'batch' axis.
Default: Scale all axes jointly.
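As a hedged sketch, normalizing each sample and channel independently (reducing only over the spatial axes) could be specified as:

>>> preprocessing = [ZeroMeanUnitVarianceDescr(
...     kwargs=ZeroMeanUnitVarianceKwargs(axes=(AxisId("x"), AxisId("y")))
... )]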
1249class ZeroMeanUnitVarianceDescr(ProcessingDescrBase): 1250 """Subtract mean and divide by variance. 1251 1252 Examples: 1253 Subtract tensor mean and variance 1254 - in YAML 1255 ```yaml 1256 preprocessing: 1257 - id: zero_mean_unit_variance 1258 ``` 1259 - in Python 1260 >>> preprocessing = [ZeroMeanUnitVarianceDescr()] 1261 """ 1262 1263 implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = ( 1264 "zero_mean_unit_variance" 1265 ) 1266 if TYPE_CHECKING: 1267 id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance" 1268 else: 1269 id: Literal["zero_mean_unit_variance"] 1270 1271 kwargs: ZeroMeanUnitVarianceKwargs = Field( 1272 default_factory=ZeroMeanUnitVarianceKwargs 1273 )
Subtract mean and divide by variance.
Examples:
Subtract tensor mean and variance
- in YAML
preprocessing:
  - id: zero_mean_unit_variance
- in Python
>>> preprocessing = [ZeroMeanUnitVarianceDescr()]
1276class ScaleRangeKwargs(ProcessingKwargs): 1277 """key word arguments for `ScaleRangeDescr` 1278 1279 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default) 1280 this processing step normalizes data to the [0, 1] intervall. 1281 For other percentiles the normalized values will partially be outside the [0, 1] 1282 intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the 1283 normalized values to a range. 1284 """ 1285 1286 axes: Annotated[ 1287 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1288 ] = None 1289 """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value. 1290 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1291 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1292 To normalize samples independently, leave out the "batch" axis. 1293 Default: Scale all axes jointly.""" 1294 1295 min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0 1296 """The lower percentile used to determine the value to align with zero.""" 1297 1298 max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0 1299 """The upper percentile used to determine the value to align with one. 1300 Has to be bigger than `min_percentile`. 1301 The range is 1 to 100 instead of 0 to 100 to avoid mistakenly 1302 accepting percentiles specified in the range 0.0 to 1.0.""" 1303 1304 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1305 """Epsilon for numeric stability. 1306 `out = (tensor - v_lower) / (v_upper - v_lower + eps)`; 1307 with `v_lower,v_upper` values at the respective percentiles.""" 1308 1309 reference_tensor: Optional[TensorId] = None 1310 """Tensor ID to compute the percentiles from. Default: The tensor itself. 1311 For any tensor in `inputs` only input tensor references are allowed.""" 1312 1313 @field_validator("max_percentile", mode="after") 1314 @classmethod 1315 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float: 1316 if (min_p := info.data["min_percentile"]) >= value: 1317 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}") 1318 1319 return value
key word arguments for ScaleRangeDescr
For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
this processing step normalizes data to the [0, 1] interval.
For other percentiles the normalized values will partially be outside the [0, 1]
interval. Use `ScaleRange` followed by `ClipDescr` if you want to limit the
normalized values to a range.
The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
To normalize samples independently, leave out the "batch" axis.
Default: Scale all axes jointly.
The lower percentile used to determine the value to align with zero.
The upper percentile used to determine the value to align with one.
Has to be bigger than `min_percentile`.
The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
accepting percentiles specified in the range 0.0 to 1.0.
Epsilon for numeric stability.
`out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
with `v_lower`, `v_upper` values at the respective percentiles.
Tensor ID to compute the percentiles from. Default: The tensor itself.
For any tensor in `inputs` only input tensor references are allowed.
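A hedged sketch of scaling an output with percentiles computed from an input tensor via `reference_tensor` (the tensor id 'raw' is made up):

>>> postprocessing = [ScaleRangeDescr(
...     kwargs=ScaleRangeKwargs(
...         axes=(AxisId("y"), AxisId("x")),
...         max_percentile=99.8,
...         min_percentile=5.0,
...         reference_tensor=TensorId("raw"),
...     )
... )]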
1313 @field_validator("max_percentile", mode="after") 1314 @classmethod 1315 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float: 1316 if (min_p := info.data["min_percentile"]) >= value: 1317 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}") 1318 1319 return value
1322class ScaleRangeDescr(ProcessingDescrBase): 1323 """Scale with percentiles. 1324 1325 Examples: 1326 1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0 1327 - in YAML 1328 ```yaml 1329 preprocessing: 1330 - id: scale_range 1331 kwargs: 1332 axes: ['y', 'x'] 1333 max_percentile: 99.8 1334 min_percentile: 5.0 1335 ``` 1336 - in Python 1337 >>> preprocessing = [ 1338 ... ScaleRangeDescr( 1339 ... kwargs=ScaleRangeKwargs( 1340 ... axes= (AxisId('y'), AxisId('x')), 1341 ... max_percentile= 99.8, 1342 ... min_percentile= 5.0, 1343 ... ) 1344 ... ), 1345 ... ClipDescr( 1346 ... kwargs=ClipKwargs( 1347 ... min=0.0, 1348 ... max=1.0, 1349 ... ) 1350 ... ), 1351 ... ] 1352 1353 2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles. 1354 - in YAML 1355 ```yaml 1356 preprocessing: 1357 - id: scale_range 1358 kwargs: 1359 axes: ['y', 'x'] 1360 max_percentile: 99.8 1361 min_percentile: 5.0 1362 - id: scale_range 1363 - id: clip 1364 kwargs: 1365 min: 0.0 1366 max: 1.0 1367 ``` 1368 - in Python 1369 >>> preprocessing = [ScaleRangeDescr( 1370 ... kwargs=ScaleRangeKwargs( 1371 ... axes= (AxisId('y'), AxisId('x')), 1372 ... max_percentile= 99.8, 1373 ... min_percentile= 5.0, 1374 ... ) 1375 ... )] 1376 1377 """ 1378 1379 implemented_id: ClassVar[Literal["scale_range"]] = "scale_range" 1380 if TYPE_CHECKING: 1381 id: Literal["scale_range"] = "scale_range" 1382 else: 1383 id: Literal["scale_range"] 1384 kwargs: ScaleRangeKwargs
Scale with percentiles.
Examples:
- Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
- in YAML
preprocessing:
- id: scale_range
kwargs:
axes: ['y', 'x']
max_percentile: 99.8
min_percentile: 5.0
- in Python
>>> preprocessing = [
... ScaleRangeDescr(
... kwargs=ScaleRangeKwargs(
... axes= (AxisId('y'), AxisId('x')),
... max_percentile= 99.8,
... min_percentile= 5.0,
... )
... ),
... ClipDescr(
... kwargs=ClipKwargs(
... min=0.0,
... max=1.0,
... )
... ),
... ]
- Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
- in YAML
preprocessing:
  - id: scale_range
    kwargs:
      axes: ['y', 'x']
      max_percentile: 99.8
      min_percentile: 5.0
  - id: clip
    kwargs:
      min: 0.0
      max: 1.0
- in Python
>>> preprocessing = [ScaleRangeDescr(
...     kwargs=ScaleRangeKwargs(
...         axes= (AxisId('y'), AxisId('x')),
...         max_percentile= 99.8,
...         min_percentile= 5.0,
...     )
... )]
1387class ScaleMeanVarianceKwargs(ProcessingKwargs): 1388 """key word arguments for `ScaleMeanVarianceKwargs`""" 1389 1390 reference_tensor: TensorId 1391 """Name of tensor to match.""" 1392 1393 axes: Annotated[ 1394 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1395 ] = None 1396 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 1397 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1398 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1399 To normalize samples independently, leave out the 'batch' axis. 1400 Default: Scale all axes jointly.""" 1401 1402 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1403 """Epsilon for numeric stability: 1404 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""
key word arguments for ScaleMeanVarianceKwargs
The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
To normalize samples independently, leave out the 'batch' axis.
Default: Scale all axes jointly.
1407class ScaleMeanVarianceDescr(ProcessingDescrBase): 1408 """Scale a tensor's data distribution to match another tensor's mean/std. 1409 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.` 1410 """ 1411 1412 implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance" 1413 if TYPE_CHECKING: 1414 id: Literal["scale_mean_variance"] = "scale_mean_variance" 1415 else: 1416 id: Literal["scale_mean_variance"] 1417 kwargs: ScaleMeanVarianceKwargs
Scale a tensor's data distribution to match another tensor's mean/std.
out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.
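A hedged usage sketch (the reference tensor id 'raw' is illustrative):

>>> postprocessing = [ScaleMeanVarianceDescr(
...     kwargs=ScaleMeanVarianceKwargs(
...         reference_tensor=TensorId("raw"),
...         axes=(AxisId("y"), AxisId("x")),
...     )
... )]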
1451class TensorDescrBase(Node, Generic[IO_AxisT]): 1452 id: TensorId 1453 """Tensor id. No duplicates are allowed.""" 1454 1455 description: Annotated[str, MaxLen(128)] = "" 1456 """free text description""" 1457 1458 axes: NotEmpty[Sequence[IO_AxisT]] 1459 """tensor axes""" 1460 1461 @property 1462 def shape(self): 1463 return tuple(a.size for a in self.axes) 1464 1465 @field_validator("axes", mode="after", check_fields=False) 1466 @classmethod 1467 def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]: 1468 batch_axes = [a for a in axes if a.type == "batch"] 1469 if len(batch_axes) > 1: 1470 raise ValueError( 1471 f"Only one batch axis (per tensor) allowed, but got {batch_axes}" 1472 ) 1473 1474 seen_ids: Set[AxisId] = set() 1475 duplicate_axes_ids: Set[AxisId] = set() 1476 for a in axes: 1477 (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id) 1478 1479 if duplicate_axes_ids: 1480 raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}") 1481 1482 return axes 1483 1484 test_tensor: FileDescr 1485 """An example tensor to use for testing. 1486 Using the model with the test input tensors is expected to yield the test output tensors. 1487 Each test tensor has be a an ndarray in the 1488 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format). 1489 The file extension must be '.npy'.""" 1490 1491 sample_tensor: Optional[FileDescr] = None 1492 """A sample tensor to illustrate a possible input/output for the model, 1493 The sample image primarily serves to inform a human user about an example use case 1494 and is typically stored as .hdf5, .png or .tiff. 1495 It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats) 1496 (numpy's `.npy` format is not supported). 1497 The image dimensionality has to match the number of axes specified in this tensor description. 1498 """ 1499 1500 @model_validator(mode="after") 1501 def _validate_sample_tensor(self) -> Self: 1502 if self.sample_tensor is None or not get_validation_context().perform_io_checks: 1503 return self 1504 1505 local = download(self.sample_tensor.source, sha256=self.sample_tensor.sha256) 1506 tensor: NDArray[Any] = imread( 1507 local.path.read_bytes(), 1508 extension=PurePosixPath(local.original_file_name).suffix, 1509 ) 1510 n_dims = len(tensor.squeeze().shape) 1511 n_dims_min = n_dims_max = len(self.axes) 1512 1513 for a in self.axes: 1514 if isinstance(a, BatchAxis): 1515 n_dims_min -= 1 1516 elif isinstance(a.size, int): 1517 if a.size == 1: 1518 n_dims_min -= 1 1519 elif isinstance(a.size, (ParameterizedSize, DataDependentSize)): 1520 if a.size.min == 1: 1521 n_dims_min -= 1 1522 elif isinstance(a.size, SizeReference): 1523 if a.size.offset < 2: 1524 # size reference may result in singleton axis 1525 n_dims_min -= 1 1526 else: 1527 assert_never(a.size) 1528 1529 n_dims_min = max(0, n_dims_min) 1530 if n_dims < n_dims_min or n_dims > n_dims_max: 1531 raise ValueError( 1532 f"Expected sample tensor to have {n_dims_min} to" 1533 + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})." 1534 ) 1535 1536 return self 1537 1538 data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = ( 1539 IntervalOrRatioDataDescr() 1540 ) 1541 """Description of the tensor's data values, optionally per channel. 
1542 If specified per channel, the data `type` needs to match across channels.""" 1543 1544 @property 1545 def dtype( 1546 self, 1547 ) -> Literal[ 1548 "float32", 1549 "float64", 1550 "uint8", 1551 "int8", 1552 "uint16", 1553 "int16", 1554 "uint32", 1555 "int32", 1556 "uint64", 1557 "int64", 1558 "bool", 1559 ]: 1560 """dtype as specified under `data.type` or `data[i].type`""" 1561 if isinstance(self.data, collections.abc.Sequence): 1562 return self.data[0].type 1563 else: 1564 return self.data.type 1565 1566 @field_validator("data", mode="after") 1567 @classmethod 1568 def _check_data_type_across_channels( 1569 cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] 1570 ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]: 1571 if not isinstance(value, list): 1572 return value 1573 1574 dtypes = {t.type for t in value} 1575 if len(dtypes) > 1: 1576 raise ValueError( 1577 "Tensor data descriptions per channel need to agree in their data" 1578 + f" `type`, but found {dtypes}." 1579 ) 1580 1581 return value 1582 1583 @model_validator(mode="after") 1584 def _check_data_matches_channelaxis(self) -> Self: 1585 if not isinstance(self.data, (list, tuple)): 1586 return self 1587 1588 for a in self.axes: 1589 if isinstance(a, ChannelAxis): 1590 size = a.size 1591 assert isinstance(size, int) 1592 break 1593 else: 1594 return self 1595 1596 if len(self.data) != size: 1597 raise ValueError( 1598 f"Got tensor data descriptions for {len(self.data)} channels, but" 1599 + f" '{a.id}' axis has size {size}." 1600 ) 1601 1602 return self 1603 1604 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]: 1605 if len(array.shape) != len(self.axes): 1606 raise ValueError( 1607 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})" 1608 + f" incompatible with {len(self.axes)} axes." 1609 ) 1610 return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
An example tensor to use for testing. Using the model with the test input tensors is expected to yield the test output tensors. Each test tensor has to be an ndarray in the numpy.lib file format. The file extension must be '.npy'.
A sample tensor to illustrate a possible input/output for the model.
The sample image primarily serves to inform a human user about an example use case
and is typically stored as .hdf5, .png or .tiff.
It has to be readable by the imageio library
(numpy's `.npy` format is not supported).
The image dimensionality has to match the number of axes specified in this tensor description.
Description of the tensor's data values, optionally per channel.
If specified per channel, the data `type` needs to match across channels.
1544 @property 1545 def dtype( 1546 self, 1547 ) -> Literal[ 1548 "float32", 1549 "float64", 1550 "uint8", 1551 "int8", 1552 "uint16", 1553 "int16", 1554 "uint32", 1555 "int32", 1556 "uint64", 1557 "int64", 1558 "bool", 1559 ]: 1560 """dtype as specified under `data.type` or `data[i].type`""" 1561 if isinstance(self.data, collections.abc.Sequence): 1562 return self.data[0].type 1563 else: 1564 return self.data.type
dtype as specified under `data.type` or `data[i].type`
1604 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]: 1605 if len(array.shape) != len(self.axes): 1606 raise ValueError( 1607 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})" 1608 + f" incompatible with {len(self.axes)} axes." 1609 ) 1610 return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
1613class InputTensorDescr(TensorDescrBase[InputAxis]): 1614 id: TensorId = TensorId("input") 1615 """Input tensor id. 1616 No duplicates are allowed across all inputs and outputs.""" 1617 1618 optional: bool = False 1619 """indicates that this tensor may be `None`""" 1620 1621 preprocessing: List[PreprocessingDescr] = Field(default_factory=list) 1622 """Description of how this input should be preprocessed. 1623 1624 notes: 1625 - If preprocessing does not start with an 'ensure_dtype' entry, it is added 1626 to ensure an input tensor's data type matches the input tensor's data description. 1627 - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 1628 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally 1629 changing the data type. 1630 """ 1631 1632 @model_validator(mode="after") 1633 def _validate_preprocessing_kwargs(self) -> Self: 1634 axes_ids = [a.id for a in self.axes] 1635 for p in self.preprocessing: 1636 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 1637 if kwargs_axes is None: 1638 continue 1639 1640 if not isinstance(kwargs_axes, collections.abc.Sequence): 1641 raise ValueError( 1642 f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}" 1643 ) 1644 1645 if any(a not in axes_ids for a in kwargs_axes): 1646 raise ValueError( 1647 "`preprocessing.i.kwargs.axes` needs to be subset of axes ids" 1648 ) 1649 1650 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 1651 dtype = self.data.type 1652 else: 1653 dtype = self.data[0].type 1654 1655 # ensure `preprocessing` begins with `EnsureDtypeDescr` 1656 if not self.preprocessing or not isinstance( 1657 self.preprocessing[0], EnsureDtypeDescr 1658 ): 1659 self.preprocessing.insert( 1660 0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1661 ) 1662 1663 # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 1664 if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)): 1665 self.preprocessing.append( 1666 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1667 ) 1668 1669 return self
Description of how this input should be preprocessed.
notes:
- If preprocessing does not start with an 'ensure_dtype' entry, it is added to ensure an input tensor's data type matches the input tensor's data description.
- If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally changing the data type.
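A minimal, hedged sketch of an input tensor description; the axis layout and the `test_tensor` file name are made up, and the referenced .npy file must exist (or io checks be disabled) for validation to pass:

>>> input_descr = InputTensorDescr(
...     id=TensorId("input"),
...     axes=[
...         BatchAxis(),
...         ChannelAxis(channel_names=[Identifier("raw")]),
...         SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=64, step=16)),
...         SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=64, step=16)),
...     ],
...     test_tensor=FileDescr(source="test_input.npy"),  # hypothetical local file
... )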
1672def convert_axes( 1673 axes: str, 1674 *, 1675 shape: Union[ 1676 Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4 1677 ], 1678 tensor_type: Literal["input", "output"], 1679 halo: Optional[Sequence[int]], 1680 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 1681): 1682 ret: List[AnyAxis] = [] 1683 for i, a in enumerate(axes): 1684 axis_type = _AXIS_TYPE_MAP.get(a, a) 1685 if axis_type == "batch": 1686 ret.append(BatchAxis()) 1687 continue 1688 1689 scale = 1.0 1690 if isinstance(shape, _ParameterizedInputShape_v0_4): 1691 if shape.step[i] == 0: 1692 size = shape.min[i] 1693 else: 1694 size = ParameterizedSize(min=shape.min[i], step=shape.step[i]) 1695 elif isinstance(shape, _ImplicitOutputShape_v0_4): 1696 ref_t = str(shape.reference_tensor) 1697 if ref_t.count(".") == 1: 1698 t_id, orig_a_id = ref_t.split(".") 1699 else: 1700 t_id = ref_t 1701 orig_a_id = a 1702 1703 a_id = _AXIS_ID_MAP.get(orig_a_id, a) 1704 if not (orig_scale := shape.scale[i]): 1705 # old way to insert a new axis dimension 1706 size = int(2 * shape.offset[i]) 1707 else: 1708 scale = 1 / orig_scale 1709 if axis_type in ("channel", "index"): 1710 # these axes no longer have a scale 1711 offset_from_scale = orig_scale * size_refs.get( 1712 _TensorName_v0_4(t_id), {} 1713 ).get(orig_a_id, 0) 1714 else: 1715 offset_from_scale = 0 1716 size = SizeReference( 1717 tensor_id=TensorId(t_id), 1718 axis_id=AxisId(a_id), 1719 offset=int(offset_from_scale + 2 * shape.offset[i]), 1720 ) 1721 else: 1722 size = shape[i] 1723 1724 if axis_type == "time": 1725 if tensor_type == "input": 1726 ret.append(TimeInputAxis(size=size, scale=scale)) 1727 else: 1728 assert not isinstance(size, ParameterizedSize) 1729 if halo is None: 1730 ret.append(TimeOutputAxis(size=size, scale=scale)) 1731 else: 1732 assert not isinstance(size, int) 1733 ret.append( 1734 TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i]) 1735 ) 1736 1737 elif axis_type == "index": 1738 if tensor_type == "input": 1739 ret.append(IndexInputAxis(size=size)) 1740 else: 1741 if isinstance(size, ParameterizedSize): 1742 size = DataDependentSize(min=size.min) 1743 1744 ret.append(IndexOutputAxis(size=size)) 1745 elif axis_type == "channel": 1746 assert not isinstance(size, ParameterizedSize) 1747 if isinstance(size, SizeReference): 1748 warnings.warn( 1749 "Conversion of channel size from an implicit output shape may be" 1750 + " wrong" 1751 ) 1752 ret.append( 1753 ChannelAxis( 1754 channel_names=[ 1755 Identifier(f"channel{i}") for i in range(size.offset) 1756 ] 1757 ) 1758 ) 1759 else: 1760 ret.append( 1761 ChannelAxis( 1762 channel_names=[Identifier(f"channel{i}") for i in range(size)] 1763 ) 1764 ) 1765 elif axis_type == "space": 1766 if tensor_type == "input": 1767 ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale)) 1768 else: 1769 assert not isinstance(size, ParameterizedSize) 1770 if halo is None or halo[i] == 0: 1771 ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale)) 1772 elif isinstance(size, int): 1773 raise NotImplementedError( 1774 f"output axis with halo and fixed size (here {size}) not allowed" 1775 ) 1776 else: 1777 ret.append( 1778 SpaceOutputAxisWithHalo( 1779 id=AxisId(a), size=size, scale=scale, halo=halo[i] 1780 ) 1781 ) 1782 1783 return ret
1943class OutputTensorDescr(TensorDescrBase[OutputAxis]): 1944 id: TensorId = TensorId("output") 1945 """Output tensor id. 1946 No duplicates are allowed across all inputs and outputs.""" 1947 1948 postprocessing: List[PostprocessingDescr] = Field(default_factory=list) 1949 """Description of how this output should be postprocessed. 1950 1951 note: `postprocessing` always ends with an 'ensure_dtype' operation. 1952 If not given this is added to cast to this tensor's `data.type`. 1953 """ 1954 1955 @model_validator(mode="after") 1956 def _validate_postprocessing_kwargs(self) -> Self: 1957 axes_ids = [a.id for a in self.axes] 1958 for p in self.postprocessing: 1959 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 1960 if kwargs_axes is None: 1961 continue 1962 1963 if not isinstance(kwargs_axes, collections.abc.Sequence): 1964 raise ValueError( 1965 f"expected `axes` sequence, but got {type(kwargs_axes)}" 1966 ) 1967 1968 if any(a not in axes_ids for a in kwargs_axes): 1969 raise ValueError("`kwargs.axes` needs to be subset of axes ids") 1970 1971 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 1972 dtype = self.data.type 1973 else: 1974 dtype = self.data[0].type 1975 1976 # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 1977 if not self.postprocessing or not isinstance( 1978 self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr) 1979 ): 1980 self.postprocessing.append( 1981 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1982 ) 1983 return self
Description of how this output should be postprocessed.
note: `postprocessing` always ends with an 'ensure_dtype' operation.
If not given this is added to cast to this tensor's `data.type`.
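An analogous hedged sketch of an output whose spatial size follows the input via `SizeReference` (ids and file name are made up; the referenced .npy file must exist or io checks be disabled):

>>> output_descr = OutputTensorDescr(
...     id=TensorId("output"),
...     axes=[
...         BatchAxis(),
...         ChannelAxis(channel_names=[Identifier("probability")]),
...         SpaceOutputAxis(id=AxisId("y"), size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("y"))),
...         SpaceOutputAxis(id=AxisId("x"), size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("x"))),
...     ],
...     test_tensor=FileDescr(source="test_output.npy"),  # hypothetical
... )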
2033def validate_tensors( 2034 tensors: Mapping[TensorId, Tuple[TensorDescr, NDArray[Any]]], 2035 tensor_origin: Literal[ 2036 "test_tensor" 2037 ], # for more precise error messages, e.g. 'test_tensor' 2038): 2039 all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, int]]] = {} 2040 2041 def e_msg(d: TensorDescr): 2042 return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]" 2043 2044 for descr, array in tensors.values(): 2045 try: 2046 axis_sizes = descr.get_axis_sizes_for_array(array) 2047 except ValueError as e: 2048 raise ValueError(f"{e_msg(descr)} {e}") 2049 else: 2050 all_tensor_axes[descr.id] = { 2051 a.id: (a, axis_sizes[a.id]) for a in descr.axes 2052 } 2053 2054 for descr, array in tensors.values(): 2055 if descr.dtype in ("float32", "float64"): 2056 invalid_test_tensor_dtype = array.dtype.name not in ( 2057 "float32", 2058 "float64", 2059 "uint8", 2060 "int8", 2061 "uint16", 2062 "int16", 2063 "uint32", 2064 "int32", 2065 "uint64", 2066 "int64", 2067 ) 2068 else: 2069 invalid_test_tensor_dtype = array.dtype.name != descr.dtype 2070 2071 if invalid_test_tensor_dtype: 2072 raise ValueError( 2073 f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not" 2074 + f" match described dtype '{descr.dtype}'" 2075 ) 2076 2077 if array.min() > -1e-4 and array.max() < 1e-4: 2078 raise ValueError( 2079 "Output values are too small for reliable testing." 2080 + f" Values <-1e5 or >=1e5 must be present in {tensor_origin}" 2081 ) 2082 2083 for a in descr.axes: 2084 actual_size = all_tensor_axes[descr.id][a.id][1] 2085 if a.size is None: 2086 continue 2087 2088 if isinstance(a.size, int): 2089 if actual_size != a.size: 2090 raise ValueError( 2091 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' " 2092 + f"has incompatible size {actual_size}, expected {a.size}" 2093 ) 2094 elif isinstance(a.size, ParameterizedSize): 2095 _ = a.size.validate_size(actual_size) 2096 elif isinstance(a.size, DataDependentSize): 2097 _ = a.size.validate_size(actual_size) 2098 elif isinstance(a.size, SizeReference): 2099 ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id) 2100 if ref_tensor_axes is None: 2101 raise ValueError( 2102 f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor" 2103 + f" reference '{a.size.tensor_id}'" 2104 ) 2105 2106 ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None)) 2107 if ref_axis is None or ref_size is None: 2108 raise ValueError( 2109 f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis" 2110 + f" reference '{a.size.tensor_id}.{a.size.axis_id}" 2111 ) 2112 2113 if a.unit != ref_axis.unit: 2114 raise ValueError( 2115 f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires" 2116 + " axis and reference axis to have the same `unit`, but" 2117 + f" {a.unit}!={ref_axis.unit}" 2118 ) 2119 2120 if actual_size != ( 2121 expected_size := ( 2122 ref_size * ref_axis.scale / a.scale + a.size.offset 2123 ) 2124 ): 2125 raise ValueError( 2126 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size" 2127 + f" {actual_size} invalid for referenced size {ref_size};" 2128 + f" expected {expected_size}" 2129 ) 2130 else: 2131 assert_never(a.size)
2134class EnvironmentFileDescr(FileDescr): 2135 source: Annotated[ 2136 ImportantFileSource, 2137 WithSuffix((".yaml", ".yml"), case_sensitive=True), 2138 Field( 2139 examples=["environment.yaml"], 2140 ), 2141 ] 2142 """∈📦 Conda environment file. 2143 Allows to specify custom dependencies, see conda docs: 2144 - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms) 2145 - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually) 2146 """
∈📦 Conda environment file.
Allows to specify custom dependencies, see conda docs:
- [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
- [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)
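For example (hedged; the file name follows the field's example and must point to an existing conda environment file when io checks run):

>>> env = EnvironmentFileDescr(source="environment.yaml")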
2161class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr): 2162 import_from: str 2163 """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
2229class WeightsEntryDescrBase(FileDescr): 2230 type: ClassVar[WeightsFormat] 2231 weights_format_name: ClassVar[str] # human readable 2232 2233 source: ImportantFileSource 2234 """∈📦 The weights file.""" 2235 2236 authors: Optional[List[Author]] = None 2237 """Authors 2238 Either the person(s) that have trained this model resulting in the original weights file. 2239 (If this is the initial weights entry, i.e. it does not have a `parent`) 2240 Or the person(s) who have converted the weights to this weights format. 2241 (If this is a child weight, i.e. it has a `parent` field) 2242 """ 2243 2244 parent: Annotated[ 2245 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"]) 2246 ] = None 2247 """The source weights these weights were converted from. 2248 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`, 2249 The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights. 2250 All weight entries except one (the initial set of weights resulting from training the model), 2251 need to have this field.""" 2252 2253 comment: str = "" 2254 """A comment about this weights entry, for example how these weights were created.""" 2255 2256 @model_validator(mode="after") 2257 def check_parent_is_not_self(self) -> Self: 2258 if self.type == self.parent: 2259 raise ValueError("Weights entry can't be it's own parent.") 2260 2261 return self
∈📦 The weights file.
The source weights these weights were converted from.
For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
the `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
All weight entries except one (the initial set of weights resulting from training the model)
need to have this field.
2264class KerasHdf5WeightsDescr(WeightsEntryDescrBase): 2265 type = "keras_hdf5" 2266 weights_format_name: ClassVar[str] = "Keras HDF5" 2267 tensorflow_version: Version 2268 """TensorFlow version used to create these weights."""
TensorFlow version used to create these weights.
2271class OnnxWeightsDescr(WeightsEntryDescrBase): 2272 type = "onnx" 2273 weights_format_name: ClassVar[str] = "ONNX" 2274 opset_version: Annotated[int, Ge(7)] 2275 """ONNX opset version"""
2278class PytorchStateDictWeightsDescr(WeightsEntryDescrBase): 2279 type = "pytorch_state_dict" 2280 weights_format_name: ClassVar[str] = "Pytorch State Dict" 2281 architecture: ArchitectureDescr 2282 pytorch_version: Version 2283 """Version of the PyTorch library used. 2284 If `architecture.depencencies` is specified it has to include pytorch and any version pinning has to be compatible. 2285 """ 2286 dependencies: Optional[EnvironmentFileDescr] = None 2287 """Custom depencies beyond pytorch. 2288 The conda environment file should include pytorch and any version pinning has to be compatible with 2289 `pytorch_version`. 2290 """
Version of the PyTorch library used. If architecture.dependencies is specified, it has to include pytorch and any version pinning has to be compatible.
Custom dependencies beyond pytorch. The conda environment file should include pytorch and any version pinning has to be compatible with pytorch_version.
2293class TensorflowJsWeightsDescr(WeightsEntryDescrBase): 2294 type = "tensorflow_js" 2295 weights_format_name: ClassVar[str] = "Tensorflow.js" 2296 tensorflow_version: Version 2297 """Version of the TensorFlow library used.""" 2298 2299 source: ImportantFileSource 2300 """∈📦 The multi-file weights. 2301 All required files/folders should be a zip archive."""
Version of the TensorFlow library used.
∈📦 The multi-file weights. All required files/folders should be a zip archive.
2304class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase): 2305 type = "tensorflow_saved_model_bundle" 2306 weights_format_name: ClassVar[str] = "Tensorflow Saved Model" 2307 tensorflow_version: Version 2308 """Version of the TensorFlow library used.""" 2309 2310 dependencies: Optional[EnvironmentFileDescr] = None 2311 """Custom dependencies beyond tensorflow. 2312 Should include tensorflow and any version pinning has to be compatible with `tensorflow_version`.""" 2313 2314 source: ImportantFileSource 2315 """∈📦 The multi-file weights. 2316 All required files/folders should be a zip archive."""
Version of the TensorFlow library used.
Custom dependencies beyond tensorflow. Should include tensorflow and any version pinning has to be compatible with tensorflow_version.
∈📦 The multi-file weights. All required files/folders should be a zip archive.
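A minimal sketch of producing the expected zip archive from a SavedModel directory (the folder name saved_model_dir is just an example):

```python
# Sketch: package a TensorFlow SavedModel directory as the single zip archive
# expected by the `source` field. "saved_model_dir" is assumed to exist.
import shutil

zip_path = shutil.make_archive("tf_saved_model_bundle", "zip", root_dir="saved_model_dir")
# `zip_path` (e.g. "tf_saved_model_bundle.zip") can then be referenced as
# weights.tensorflow_saved_model_bundle.source.
```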
2319class TorchscriptWeightsDescr(WeightsEntryDescrBase): 2320 type = "torchscript" 2321 weights_format_name: ClassVar[str] = "TorchScript" 2322 pytorch_version: Version 2323 """Version of the PyTorch library used."""
Version of the PyTorch library used.
2326class WeightsDescr(Node): 2327 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None 2328 onnx: Optional[OnnxWeightsDescr] = None 2329 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None 2330 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None 2331 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = ( 2332 None 2333 ) 2334 torchscript: Optional[TorchscriptWeightsDescr] = None 2335 2336 @model_validator(mode="after") 2337 def check_entries(self) -> Self: 2338 entries = {wtype for wtype, entry in self if entry is not None} 2339 2340 if not entries: 2341 raise ValueError("Missing weights entry") 2342 2343 entries_wo_parent = { 2344 wtype 2345 for wtype, entry in self 2346 if entry is not None and hasattr(entry, "parent") and entry.parent is None 2347 } 2348 if len(entries_wo_parent) != 1: 2349 issue_warning( 2350 "Exactly one weights entry may not specify the `parent` field (got" 2351 + " {value}). That entry is considered the original set of model weights." 2352 + " Other weight formats are created through conversion of the orignal or" 2353 + " already converted weights. They have to reference the weights format" 2354 + " they were converted from as their `parent`.", 2355 value=len(entries_wo_parent), 2356 field="weights", 2357 ) 2358 2359 for wtype, entry in self: 2360 if entry is None: 2361 continue 2362 2363 assert hasattr(entry, "type") 2364 assert hasattr(entry, "parent") 2365 assert wtype == entry.type 2366 if ( 2367 entry.parent is not None and entry.parent not in entries 2368 ): # self reference checked for `parent` field 2369 raise ValueError( 2370 f"`weights.{wtype}.parent={entry.parent} not in specified weight" 2371 + f" formats: {entries}" 2372 ) 2373 2374 return self 2375 2376 def __getitem__( 2377 self, 2378 key: Literal[ 2379 "keras_hdf5", 2380 "onnx", 2381 "pytorch_state_dict", 2382 "tensorflow_js", 2383 "tensorflow_saved_model_bundle", 2384 "torchscript", 2385 ], 2386 ): 2387 if key == "keras_hdf5": 2388 ret = self.keras_hdf5 2389 elif key == "onnx": 2390 ret = self.onnx 2391 elif key == "pytorch_state_dict": 2392 ret = self.pytorch_state_dict 2393 elif key == "tensorflow_js": 2394 ret = self.tensorflow_js 2395 elif key == "tensorflow_saved_model_bundle": 2396 ret = self.tensorflow_saved_model_bundle 2397 elif key == "torchscript": 2398 ret = self.torchscript 2399 else: 2400 raise KeyError(key) 2401 2402 if ret is None: 2403 raise KeyError(key) 2404 2405 return ret 2406 2407 @property 2408 def available_formats(self): 2409 return { 2410 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 2411 **({} if self.onnx is None else {"onnx": self.onnx}), 2412 **( 2413 {} 2414 if self.pytorch_state_dict is None 2415 else {"pytorch_state_dict": self.pytorch_state_dict} 2416 ), 2417 **( 2418 {} 2419 if self.tensorflow_js is None 2420 else {"tensorflow_js": self.tensorflow_js} 2421 ), 2422 **( 2423 {} 2424 if self.tensorflow_saved_model_bundle is None 2425 else { 2426 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 2427 } 2428 ), 2429 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 2430 } 2431 2432 @property 2433 def missing_formats(self): 2434 return { 2435 wf for wf in get_args(WeightsFormat) if wf not in self.available_formats 2436 }
2336 @model_validator(mode="after") 2337 def check_entries(self) -> Self: 2338 entries = {wtype for wtype, entry in self if entry is not None} 2339 2340 if not entries: 2341 raise ValueError("Missing weights entry") 2342 2343 entries_wo_parent = { 2344 wtype 2345 for wtype, entry in self 2346 if entry is not None and hasattr(entry, "parent") and entry.parent is None 2347 } 2348 if len(entries_wo_parent) != 1: 2349 issue_warning( 2350 "Exactly one weights entry may not specify the `parent` field (got" 2351 + " {value}). That entry is considered the original set of model weights." 2352 + " Other weight formats are created through conversion of the orignal or" 2353 + " already converted weights. They have to reference the weights format" 2354 + " they were converted from as their `parent`.", 2355 value=len(entries_wo_parent), 2356 field="weights", 2357 ) 2358 2359 for wtype, entry in self: 2360 if entry is None: 2361 continue 2362 2363 assert hasattr(entry, "type") 2364 assert hasattr(entry, "parent") 2365 assert wtype == entry.type 2366 if ( 2367 entry.parent is not None and entry.parent not in entries 2368 ): # self reference checked for `parent` field 2369 raise ValueError( 2370 f"`weights.{wtype}.parent={entry.parent} not in specified weight" 2371 + f" formats: {entries}" 2372 ) 2373 2374 return self
2407 @property 2408 def available_formats(self): 2409 return { 2410 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 2411 **({} if self.onnx is None else {"onnx": self.onnx}), 2412 **( 2413 {} 2414 if self.pytorch_state_dict is None 2415 else {"pytorch_state_dict": self.pytorch_state_dict} 2416 ), 2417 **( 2418 {} 2419 if self.tensorflow_js is None 2420 else {"tensorflow_js": self.tensorflow_js} 2421 ), 2422 **( 2423 {} 2424 if self.tensorflow_saved_model_bundle is None 2425 else { 2426 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 2427 } 2428 ), 2429 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 2430 }
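The parent rule enforced by check_entries can be summarized with a small standalone sketch over a plain mapping (note that the real validator only warns, rather than errors, when the number of parent-less entries is not exactly one):

```python
# Standalone sketch of the parent rule: exactly one entry is the original
# (parent=None); every other entry must name an existing format as its parent
# and may not be its own parent.
from typing import Dict, Optional

def check_weight_parents(parents: Dict[str, Optional[str]]) -> None:
    originals = [fmt for fmt, parent in parents.items() if parent is None]
    if len(originals) != 1:
        raise ValueError(f"expected exactly one original weights entry, got {originals}")
    for fmt, parent in parents.items():
        if parent is not None and (parent == fmt or parent not in parents):
            raise ValueError(f"invalid parent {parent!r} for weights format {fmt!r}")

check_weight_parents({"pytorch_state_dict": None, "torchscript": "pytorch_state_dict"})
```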
2443class LinkedModel(LinkedResourceBase): 2444 """Reference to a bioimage.io model.""" 2445 2446 id: ModelId 2447 """A valid model `id` from the bioimage.io collection."""
Reference to a bioimage.io model.
2469class ReproducibilityTolerance(Node, extra="allow"): 2470 """Describes what small numerical differences -- if any -- may be tolerated 2471 in the generated output when executing in different environments. 2472 2473 A tensor element *output* is considered mismatched to the **test_tensor** if 2474 abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**). 2475 (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).) 2476 2477 Motivation: 2478 For testing we can request the respective deep learning frameworks to be as 2479 reproducible as possible by setting seeds and chosing deterministic algorithms, 2480 but differences in operating systems, available hardware and installed drivers 2481 may still lead to numerical differences. 2482 """ 2483 2484 relative_tolerance: RelativeTolerance = 1e-3 2485 """Maximum relative tolerance of reproduced test tensor.""" 2486 2487 absolute_tolerance: AbsoluteTolerance = 1e-4 2488 """Maximum absolute tolerance of reproduced test tensor.""" 2489 2490 mismatched_elements_per_million: MismatchedElementsPerMillion = 0 2491 """Maximum number of mismatched elements/pixels per million to tolerate.""" 2492 2493 output_ids: Sequence[TensorId] = () 2494 """Limits the output tensor IDs these reproducibility details apply to.""" 2495 2496 weights_formats: Sequence[WeightsFormat] = () 2497 """Limits the weights formats these details apply to."""
Describes what small numerical differences -- if any -- may be tolerated in the generated output when executing in different environments.
A tensor element output is considered mismatched to the test_tensor if abs(output - test_tensor) > absolute_tolerance + relative_tolerance * abs(test_tensor). (Internally we call numpy.testing.assert_allclose.)
Motivation:
For testing we can request the respective deep learning frameworks to be as reproducible as possible by setting seeds and choosing deterministic algorithms, but differences in operating systems, available hardware and installed drivers may still lead to numerical differences.
Maximum relative tolerance of reproduced test tensor.
Maximum absolute tolerance of reproduced test tensor.
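A small numpy sketch of the mismatch criterion above, counting mismatched elements per million under the default tolerances:

```python
# Sketch of the documented criterion:
# abs(output - test_tensor) > absolute_tolerance + relative_tolerance * abs(test_tensor)
import numpy as np

def mismatched_per_million(output, test_tensor, rtol=1e-3, atol=1e-4) -> float:
    mismatch = np.abs(output - test_tensor) > atol + rtol * np.abs(test_tensor)
    return 1e6 * mismatch.mean()

rng = np.random.default_rng(0)
test = rng.normal(size=(64, 64)).astype("float32")
reproduced = test + rng.normal(scale=1e-5, size=test.shape).astype("float32")
print(mismatched_per_million(reproduced, test))  # expected to be 0 or very small
```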
2500class BioimageioConfig(Node, extra="allow"): 2501 reproducibility_tolerance: Sequence[ReproducibilityTolerance] = () 2502 """Tolerances to allow when reproducing the model's test outputs 2503 from the model's test inputs. 2504 Only the first entry matching tensor id and weights format is considered. 2505 """
Tolerances to allow when reproducing the model's test outputs from the model's test inputs. Only the first entry matching tensor id and weights format is considered.
2508class Config(Node, extra="allow"): 2509 bioimageio: BioimageioConfig = Field(default_factory=BioimageioConfig)
2512class ModelDescr(GenericModelDescrBase): 2513 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. 2514 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 2515 """ 2516 2517 implemented_format_version: ClassVar[Literal["0.5.4"]] = "0.5.4" 2518 if TYPE_CHECKING: 2519 format_version: Literal["0.5.4"] = "0.5.4" 2520 else: 2521 format_version: Literal["0.5.4"] 2522 """Version of the bioimage.io model description specification used. 2523 When creating a new model always use the latest micro/patch version described here. 2524 The `format_version` is important for any consumer software to understand how to parse the fields. 2525 """ 2526 2527 implemented_type: ClassVar[Literal["model"]] = "model" 2528 if TYPE_CHECKING: 2529 type: Literal["model"] = "model" 2530 else: 2531 type: Literal["model"] 2532 """Specialized resource type 'model'""" 2533 2534 id: Optional[ModelId] = None 2535 """bioimage.io-wide unique resource identifier 2536 assigned by bioimage.io; version **un**specific.""" 2537 2538 authors: NotEmpty[List[Author]] 2539 """The authors are the creators of the model RDF and the primary points of contact.""" 2540 2541 documentation: Annotated[ 2542 DocumentationSource, 2543 Field( 2544 examples=[ 2545 "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md", 2546 "README.md", 2547 ], 2548 ), 2549 ] 2550 """∈📦 URL or relative path to a markdown file with additional documentation. 2551 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. 2552 The documentation should include a '#[#] Validation' (sub)section 2553 with details on how to quantitatively validate the model on unseen data.""" 2554 2555 @field_validator("documentation", mode="after") 2556 @classmethod 2557 def _validate_documentation(cls, value: DocumentationSource) -> DocumentationSource: 2558 if not get_validation_context().perform_io_checks: 2559 return value 2560 2561 doc_path = download(value).path 2562 doc_content = doc_path.read_text(encoding="utf-8") 2563 assert isinstance(doc_content, str) 2564 if not re.search("#.*[vV]alidation", doc_content): 2565 issue_warning( 2566 "No '# Validation' (sub)section found in {value}.", 2567 value=value, 2568 field="documentation", 2569 ) 2570 2571 return value 2572 2573 inputs: NotEmpty[Sequence[InputTensorDescr]] 2574 """Describes the input tensors expected by this model.""" 2575 2576 @field_validator("inputs", mode="after") 2577 @classmethod 2578 def _validate_input_axes( 2579 cls, inputs: Sequence[InputTensorDescr] 2580 ) -> Sequence[InputTensorDescr]: 2581 input_size_refs = cls._get_axes_with_independent_size(inputs) 2582 2583 for i, ipt in enumerate(inputs): 2584 valid_independent_refs: Dict[ 2585 Tuple[TensorId, AxisId], 2586 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2587 ] = { 2588 **{ 2589 (ipt.id, a.id): (ipt, a, a.size) 2590 for a in ipt.axes 2591 if not isinstance(a, BatchAxis) 2592 and isinstance(a.size, (int, ParameterizedSize)) 2593 }, 2594 **input_size_refs, 2595 } 2596 for a, ax in enumerate(ipt.axes): 2597 cls._validate_axis( 2598 "inputs", 2599 i=i, 2600 tensor_id=ipt.id, 2601 a=a, 2602 axis=ax, 2603 valid_independent_refs=valid_independent_refs, 2604 ) 2605 return inputs 2606 2607 @staticmethod 2608 def _validate_axis( 2609 field_name: str, 2610 i: int, 2611 tensor_id: TensorId, 2612 a: int, 2613 axis: AnyAxis, 2614 
valid_independent_refs: Dict[ 2615 Tuple[TensorId, AxisId], 2616 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2617 ], 2618 ): 2619 if isinstance(axis, BatchAxis) or isinstance( 2620 axis.size, (int, ParameterizedSize, DataDependentSize) 2621 ): 2622 return 2623 elif not isinstance(axis.size, SizeReference): 2624 assert_never(axis.size) 2625 2626 # validate axis.size SizeReference 2627 ref = (axis.size.tensor_id, axis.size.axis_id) 2628 if ref not in valid_independent_refs: 2629 raise ValueError( 2630 "Invalid tensor axis reference at" 2631 + f" {field_name}[{i}].axes[{a}].size: {axis.size}." 2632 ) 2633 if ref == (tensor_id, axis.id): 2634 raise ValueError( 2635 "Self-referencing not allowed for" 2636 + f" {field_name}[{i}].axes[{a}].size: {axis.size}" 2637 ) 2638 if axis.type == "channel": 2639 if valid_independent_refs[ref][1].type != "channel": 2640 raise ValueError( 2641 "A channel axis' size may only reference another fixed size" 2642 + " channel axis." 2643 ) 2644 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names: 2645 ref_size = valid_independent_refs[ref][2] 2646 assert isinstance(ref_size, int), ( 2647 "channel axis ref (another channel axis) has to specify fixed" 2648 + " size" 2649 ) 2650 generated_channel_names = [ 2651 Identifier(axis.channel_names.format(i=i)) 2652 for i in range(1, ref_size + 1) 2653 ] 2654 axis.channel_names = generated_channel_names 2655 2656 if (ax_unit := getattr(axis, "unit", None)) != ( 2657 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None) 2658 ): 2659 raise ValueError( 2660 "The units of an axis and its reference axis need to match, but" 2661 + f" '{ax_unit}' != '{ref_unit}'." 2662 ) 2663 ref_axis = valid_independent_refs[ref][1] 2664 if isinstance(ref_axis, BatchAxis): 2665 raise ValueError( 2666 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}" 2667 + " (a batch axis is not allowed as reference)." 2668 ) 2669 2670 if isinstance(axis, WithHalo): 2671 min_size = axis.size.get_size(axis, ref_axis, n=0) 2672 if (min_size - 2 * axis.halo) < 1: 2673 raise ValueError( 2674 f"axis {axis.id} with minimum size {min_size} is too small for halo" 2675 + f" {axis.halo}." 2676 ) 2677 2678 input_halo = axis.halo * axis.scale / ref_axis.scale 2679 if input_halo != int(input_halo) or input_halo % 2 == 1: 2680 raise ValueError( 2681 f"input_halo {input_halo} (output_halo {axis.halo} *" 2682 + f" output_scale {axis.scale} / input_scale {ref_axis.scale})" 2683 + f" {tensor_id}.{axis.id}." 
2684 ) 2685 2686 @model_validator(mode="after") 2687 def _validate_test_tensors(self) -> Self: 2688 if not get_validation_context().perform_io_checks: 2689 return self 2690 2691 test_output_arrays = [ 2692 load_array(descr.test_tensor.download().path) for descr in self.outputs 2693 ] 2694 test_input_arrays = [ 2695 load_array(descr.test_tensor.download().path) for descr in self.inputs 2696 ] 2697 2698 tensors = { 2699 descr.id: (descr, array) 2700 for descr, array in zip( 2701 chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays 2702 ) 2703 } 2704 validate_tensors(tensors, tensor_origin="test_tensor") 2705 2706 output_arrays = { 2707 descr.id: array for descr, array in zip(self.outputs, test_output_arrays) 2708 } 2709 for rep_tol in self.config.bioimageio.reproducibility_tolerance: 2710 if not rep_tol.absolute_tolerance: 2711 continue 2712 2713 if rep_tol.output_ids: 2714 out_arrays = { 2715 oid: a 2716 for oid, a in output_arrays.items() 2717 if oid in rep_tol.output_ids 2718 } 2719 else: 2720 out_arrays = output_arrays 2721 2722 for out_id, array in out_arrays.items(): 2723 if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01: 2724 raise ValueError( 2725 "config.bioimageio.reproducibility_tolerance.absolute_tolerance=" 2726 + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}" 2727 + f" (1% of the maximum value of the test tensor '{out_id}')" 2728 ) 2729 2730 return self 2731 2732 @model_validator(mode="after") 2733 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self: 2734 ipt_refs = {t.id for t in self.inputs} 2735 out_refs = {t.id for t in self.outputs} 2736 for ipt in self.inputs: 2737 for p in ipt.preprocessing: 2738 ref = p.kwargs.get("reference_tensor") 2739 if ref is None: 2740 continue 2741 if ref not in ipt_refs: 2742 raise ValueError( 2743 f"`reference_tensor` '{ref}' not found. Valid input tensor" 2744 + f" references are: {ipt_refs}." 2745 ) 2746 2747 for out in self.outputs: 2748 for p in out.postprocessing: 2749 ref = p.kwargs.get("reference_tensor") 2750 if ref is None: 2751 continue 2752 2753 if ref not in ipt_refs and ref not in out_refs: 2754 raise ValueError( 2755 f"`reference_tensor` '{ref}' not found. Valid tensor references" 2756 + f" are: {ipt_refs | out_refs}." 2757 ) 2758 2759 return self 2760 2761 # TODO: use validate funcs in validate_test_tensors 2762 # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]: 2763 2764 name: Annotated[ 2765 Annotated[ 2766 str, RestrictCharacters(string.ascii_letters + string.digits + "_+- ()") 2767 ], 2768 MinLen(5), 2769 MaxLen(128), 2770 warn(MaxLen(64), "Name longer than 64 characters.", INFO), 2771 ] 2772 """A human-readable name of this model. 2773 It should be no longer than 64 characters 2774 and may only contain letter, number, underscore, minus, parentheses and spaces. 2775 We recommend to chose a name that refers to the model's task and image modality. 
2776 """ 2777 2778 outputs: NotEmpty[Sequence[OutputTensorDescr]] 2779 """Describes the output tensors.""" 2780 2781 @field_validator("outputs", mode="after") 2782 @classmethod 2783 def _validate_tensor_ids( 2784 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo 2785 ) -> Sequence[OutputTensorDescr]: 2786 tensor_ids = [ 2787 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", []) 2788 ] 2789 duplicate_tensor_ids: List[str] = [] 2790 seen: Set[str] = set() 2791 for t in tensor_ids: 2792 if t in seen: 2793 duplicate_tensor_ids.append(t) 2794 2795 seen.add(t) 2796 2797 if duplicate_tensor_ids: 2798 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}") 2799 2800 return outputs 2801 2802 @staticmethod 2803 def _get_axes_with_parameterized_size( 2804 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2805 ): 2806 return { 2807 f"{t.id}.{a.id}": (t, a, a.size) 2808 for t in io 2809 for a in t.axes 2810 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize) 2811 } 2812 2813 @staticmethod 2814 def _get_axes_with_independent_size( 2815 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2816 ): 2817 return { 2818 (t.id, a.id): (t, a, a.size) 2819 for t in io 2820 for a in t.axes 2821 if not isinstance(a, BatchAxis) 2822 and isinstance(a.size, (int, ParameterizedSize)) 2823 } 2824 2825 @field_validator("outputs", mode="after") 2826 @classmethod 2827 def _validate_output_axes( 2828 cls, outputs: List[OutputTensorDescr], info: ValidationInfo 2829 ) -> List[OutputTensorDescr]: 2830 input_size_refs = cls._get_axes_with_independent_size( 2831 info.data.get("inputs", []) 2832 ) 2833 output_size_refs = cls._get_axes_with_independent_size(outputs) 2834 2835 for i, out in enumerate(outputs): 2836 valid_independent_refs: Dict[ 2837 Tuple[TensorId, AxisId], 2838 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2839 ] = { 2840 **{ 2841 (out.id, a.id): (out, a, a.size) 2842 for a in out.axes 2843 if not isinstance(a, BatchAxis) 2844 and isinstance(a.size, (int, ParameterizedSize)) 2845 }, 2846 **input_size_refs, 2847 **output_size_refs, 2848 } 2849 for a, ax in enumerate(out.axes): 2850 cls._validate_axis( 2851 "outputs", 2852 i, 2853 out.id, 2854 a, 2855 ax, 2856 valid_independent_refs=valid_independent_refs, 2857 ) 2858 2859 return outputs 2860 2861 packaged_by: List[Author] = Field(default_factory=list) 2862 """The persons that have packaged and uploaded this model. 2863 Only required if those persons differ from the `authors`.""" 2864 2865 parent: Optional[LinkedModel] = None 2866 """The model from which this model is derived, e.g. by fine-tuning the weights.""" 2867 2868 @model_validator(mode="after") 2869 def _validate_parent_is_not_self(self) -> Self: 2870 if self.parent is not None and self.parent.id == self.id: 2871 raise ValueError("A model description may not reference itself as parent.") 2872 2873 return self 2874 2875 run_mode: Annotated[ 2876 Optional[RunMode], 2877 warn(None, "Run mode '{value}' has limited support across consumer softwares."), 2878 ] = None 2879 """Custom run mode for this model: for more complex prediction procedures like test time 2880 data augmentation that currently cannot be expressed in the specification. 
2881 No standard run modes are defined yet.""" 2882 2883 timestamp: Datetime = Field(default_factory=Datetime.now) 2884 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format 2885 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat). 2886 (In Python a datetime object is valid, too).""" 2887 2888 training_data: Annotated[ 2889 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02], 2890 Field(union_mode="left_to_right"), 2891 ] = None 2892 """The dataset used to train this model""" 2893 2894 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)] 2895 """The weights for this model. 2896 Weights can be given for different formats, but should otherwise be equivalent. 2897 The available weight formats determine which consumers can use this model.""" 2898 2899 config: Config = Field(default_factory=Config) 2900 2901 @model_validator(mode="after") 2902 def _add_default_cover(self) -> Self: 2903 if not get_validation_context().perform_io_checks or self.covers: 2904 return self 2905 2906 try: 2907 generated_covers = generate_covers( 2908 [(t, load_array(t.test_tensor.download().path)) for t in self.inputs], 2909 [(t, load_array(t.test_tensor.download().path)) for t in self.outputs], 2910 ) 2911 except Exception as e: 2912 issue_warning( 2913 "Failed to generate cover image(s): {e}", 2914 value=self.covers, 2915 msg_context=dict(e=e), 2916 field="covers", 2917 ) 2918 else: 2919 self.covers.extend(generated_covers) 2920 2921 return self 2922 2923 def get_input_test_arrays(self) -> List[NDArray[Any]]: 2924 data = [load_array(ipt.test_tensor.download().path) for ipt in self.inputs] 2925 assert all(isinstance(d, np.ndarray) for d in data) 2926 return data 2927 2928 def get_output_test_arrays(self) -> List[NDArray[Any]]: 2929 data = [load_array(out.test_tensor.download().path) for out in self.outputs] 2930 assert all(isinstance(d, np.ndarray) for d in data) 2931 return data 2932 2933 @staticmethod 2934 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 2935 batch_size = 1 2936 tensor_with_batchsize: Optional[TensorId] = None 2937 for tid in tensor_sizes: 2938 for aid, s in tensor_sizes[tid].items(): 2939 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 2940 continue 2941 2942 if batch_size != 1: 2943 assert tensor_with_batchsize is not None 2944 raise ValueError( 2945 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 2946 ) 2947 2948 batch_size = s 2949 tensor_with_batchsize = tid 2950 2951 return batch_size 2952 2953 def get_output_tensor_sizes( 2954 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 2955 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: 2956 """Returns the tensor output sizes for given **input_sizes**. 2957 Only if **input_sizes** has a valid input shape, the tensor output size is exact. 
2958 Otherwise it might be larger than the actual (valid) output""" 2959 batch_size = self.get_batch_size(input_sizes) 2960 ns = self.get_ns(input_sizes) 2961 2962 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 2963 return tensor_sizes.outputs 2964 2965 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 2966 """get parameter `n` for each parameterized axis 2967 such that the valid input size is >= the given input size""" 2968 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 2969 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 2970 for tid in input_sizes: 2971 for aid, s in input_sizes[tid].items(): 2972 size_descr = axes[tid][aid].size 2973 if isinstance(size_descr, ParameterizedSize): 2974 ret[(tid, aid)] = size_descr.get_n(s) 2975 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 2976 pass 2977 else: 2978 assert_never(size_descr) 2979 2980 return ret 2981 2982 def get_tensor_sizes( 2983 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 2984 ) -> _TensorSizes: 2985 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 2986 return _TensorSizes( 2987 { 2988 t: { 2989 aa: axis_sizes.inputs[(tt, aa)] 2990 for tt, aa in axis_sizes.inputs 2991 if tt == t 2992 } 2993 for t in {tt for tt, _ in axis_sizes.inputs} 2994 }, 2995 { 2996 t: { 2997 aa: axis_sizes.outputs[(tt, aa)] 2998 for tt, aa in axis_sizes.outputs 2999 if tt == t 3000 } 3001 for t in {tt for tt, _ in axis_sizes.outputs} 3002 }, 3003 ) 3004 3005 def get_axis_sizes( 3006 self, 3007 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], 3008 batch_size: Optional[int] = None, 3009 *, 3010 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None, 3011 ) -> _AxisSizes: 3012 """Determine input and output block shape for scale factors **ns** 3013 of parameterized input sizes. 3014 3015 Args: 3016 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) 3017 that is parameterized as `size = min + n * step`. 3018 batch_size: The desired size of the batch dimension. 3019 If given **batch_size** overwrites any batch size present in 3020 **max_input_shape**. Default 1. 3021 max_input_shape: Limits the derived block shapes. 3022 Each axis for which the input size, parameterized by `n`, is larger 3023 than **max_input_shape** is set to the minimal value `n_min` for which 3024 this is still true. 3025 Use this for small input samples or large values of **ns**. 3026 Or simply whenever you know the full input shape. 3027 3028 Returns: 3029 Resolved axis sizes for model inputs and outputs. 
3030 """ 3031 max_input_shape = max_input_shape or {} 3032 if batch_size is None: 3033 for (_t_id, a_id), s in max_input_shape.items(): 3034 if a_id == BATCH_AXIS_ID: 3035 batch_size = s 3036 break 3037 else: 3038 batch_size = 1 3039 3040 all_axes = { 3041 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs) 3042 } 3043 3044 inputs: Dict[Tuple[TensorId, AxisId], int] = {} 3045 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {} 3046 3047 def get_axis_size(a: Union[InputAxis, OutputAxis]): 3048 if isinstance(a, BatchAxis): 3049 if (t_descr.id, a.id) in ns: 3050 logger.warning( 3051 "Ignoring unexpected size increment factor (n) for batch axis" 3052 + " of tensor '{}'.", 3053 t_descr.id, 3054 ) 3055 return batch_size 3056 elif isinstance(a.size, int): 3057 if (t_descr.id, a.id) in ns: 3058 logger.warning( 3059 "Ignoring unexpected size increment factor (n) for fixed size" 3060 + " axis '{}' of tensor '{}'.", 3061 a.id, 3062 t_descr.id, 3063 ) 3064 return a.size 3065 elif isinstance(a.size, ParameterizedSize): 3066 if (t_descr.id, a.id) not in ns: 3067 raise ValueError( 3068 "Size increment factor (n) missing for parametrized axis" 3069 + f" '{a.id}' of tensor '{t_descr.id}'." 3070 ) 3071 n = ns[(t_descr.id, a.id)] 3072 s_max = max_input_shape.get((t_descr.id, a.id)) 3073 if s_max is not None: 3074 n = min(n, a.size.get_n(s_max)) 3075 3076 return a.size.get_size(n) 3077 3078 elif isinstance(a.size, SizeReference): 3079 if (t_descr.id, a.id) in ns: 3080 logger.warning( 3081 "Ignoring unexpected size increment factor (n) for axis '{}'" 3082 + " of tensor '{}' with size reference.", 3083 a.id, 3084 t_descr.id, 3085 ) 3086 assert not isinstance(a, BatchAxis) 3087 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id] 3088 assert not isinstance(ref_axis, BatchAxis) 3089 ref_key = (a.size.tensor_id, a.size.axis_id) 3090 ref_size = inputs.get(ref_key, outputs.get(ref_key)) 3091 assert ref_size is not None, ref_key 3092 assert not isinstance(ref_size, _DataDepSize), ref_key 3093 return a.size.get_size( 3094 axis=a, 3095 ref_axis=ref_axis, 3096 ref_size=ref_size, 3097 ) 3098 elif isinstance(a.size, DataDependentSize): 3099 if (t_descr.id, a.id) in ns: 3100 logger.warning( 3101 "Ignoring unexpected increment factor (n) for data dependent" 3102 + " size axis '{}' of tensor '{}'.", 3103 a.id, 3104 t_descr.id, 3105 ) 3106 return _DataDepSize(a.size.min, a.size.max) 3107 else: 3108 assert_never(a.size) 3109 3110 # first resolve all , but the `SizeReference` input sizes 3111 for t_descr in self.inputs: 3112 for a in t_descr.axes: 3113 if not isinstance(a.size, SizeReference): 3114 s = get_axis_size(a) 3115 assert not isinstance(s, _DataDepSize) 3116 inputs[t_descr.id, a.id] = s 3117 3118 # resolve all other input axis sizes 3119 for t_descr in self.inputs: 3120 for a in t_descr.axes: 3121 if isinstance(a.size, SizeReference): 3122 s = get_axis_size(a) 3123 assert not isinstance(s, _DataDepSize) 3124 inputs[t_descr.id, a.id] = s 3125 3126 # resolve all output axis sizes 3127 for t_descr in self.outputs: 3128 for a in t_descr.axes: 3129 assert not isinstance(a.size, ParameterizedSize) 3130 s = get_axis_size(a) 3131 outputs[t_descr.id, a.id] = s 3132 3133 return _AxisSizes(inputs=inputs, outputs=outputs) 3134 3135 @model_validator(mode="before") 3136 @classmethod 3137 def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]: 3138 cls.convert_from_old_format_wo_validation(data) 3139 return data 3140 3141 @classmethod 3142 def convert_from_old_format_wo_validation(cls, 
data: Dict[str, Any]) -> None: 3143 """Convert metadata following an older format version to this classes' format 3144 without validating the result. 3145 """ 3146 if ( 3147 data.get("type") == "model" 3148 and isinstance(fv := data.get("format_version"), str) 3149 and fv.count(".") == 2 3150 ): 3151 fv_parts = fv.split(".") 3152 if any(not p.isdigit() for p in fv_parts): 3153 return 3154 3155 fv_tuple = tuple(map(int, fv_parts)) 3156 3157 assert cls.implemented_format_version_tuple[0:2] == (0, 5) 3158 if fv_tuple[:2] in ((0, 3), (0, 4)): 3159 m04 = _ModelDescr_v0_4.load(data) 3160 if isinstance(m04, InvalidDescr): 3161 try: 3162 updated = _model_conv.convert_as_dict( 3163 m04 # pyright: ignore[reportArgumentType] 3164 ) 3165 except Exception as e: 3166 logger.error( 3167 "Failed to convert from invalid model 0.4 description." 3168 + f"\nerror: {e}" 3169 + "\nProceeding with model 0.5 validation without conversion." 3170 ) 3171 updated = None 3172 else: 3173 updated = _model_conv.convert_as_dict(m04) 3174 3175 if updated is not None: 3176 data.clear() 3177 data.update(updated) 3178 3179 elif fv_tuple[:2] == (0, 5): 3180 # bump patch version 3181 data["format_version"] = cls.implemented_format_version
Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
bioimage.io-wide unique resource identifier assigned by bioimage.io; version unspecific.
∈📦 URL or relative path to a markdown file with additional documentation. The recommended documentation file name is README.md. An .md suffix is mandatory. The documentation should include a '#[#] Validation' (sub)section with details on how to quantitatively validate the model on unseen data.
Describes the input tensors expected by this model.
A human-readable name of this model. It should be no longer than 64 characters and may only contain letters, numbers, underscores, minus signs, parentheses and spaces. We recommend choosing a name that refers to the model's task and image modality (see the name-checking sketch below).
The persons that have packaged and uploaded this model. Only required if those persons differ from the authors.
The model from which this model is derived, e.g. by fine-tuning the weights.
Custom run mode for this model: for more complex prediction procedures like test time data augmentation that currently cannot be expressed in the specification. No standard run modes are defined yet.
The dataset used to train this model.
The weights for this model. Weights can be given for different formats, but should otherwise be equivalent. The available weight formats determine which consumers can use this model.
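As referenced above, a small sketch checking a candidate model name against the documented character set and length limits (standalone, not using this module's validators):

```python
# Sketch: allowed characters are letters, digits, "_+- ()"; length 5-128;
# names longer than 64 characters only trigger a warning.
import re
import string

ALLOWED = re.escape(string.ascii_letters + string.digits + "_+- ()")
NAME_RE = re.compile(f"^[{ALLOWED}]{{5,128}}$")

def check_name(name: str) -> None:
    if not NAME_RE.fullmatch(name):
        raise ValueError(f"invalid model name: {name!r}")
    if len(name) > 64:
        print("warning: name longer than 64 characters")

check_name("2D UNet (nuclei segmentation)")
```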
2933 @staticmethod 2934 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 2935 batch_size = 1 2936 tensor_with_batchsize: Optional[TensorId] = None 2937 for tid in tensor_sizes: 2938 for aid, s in tensor_sizes[tid].items(): 2939 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 2940 continue 2941 2942 if batch_size != 1: 2943 assert tensor_with_batchsize is not None 2944 raise ValueError( 2945 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 2946 ) 2947 2948 batch_size = s 2949 tensor_with_batchsize = tid 2950 2951 return batch_size
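Since get_batch_size is a staticmethod, it can be called without a ModelDescr instance. A usage sketch, assuming TensorId and AxisId can be constructed directly from strings as shown (the ids "raw", "batch", "x", "y" are made up):

```python
# Sketch: resolve the batch size from per-tensor axis sizes.
from bioimageio.spec.model.v0_5 import AxisId, ModelDescr, TensorId

sizes = {TensorId("raw"): {AxisId("batch"): 4, AxisId("x"): 512, AxisId("y"): 512}}
print(ModelDescr.get_batch_size(sizes))  # 4
```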
2953 def get_output_tensor_sizes( 2954 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 2955 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: 2956 """Returns the tensor output sizes for given **input_sizes**. 2957 Only if **input_sizes** has a valid input shape, the tensor output size is exact. 2958 Otherwise it might be larger than the actual (valid) output""" 2959 batch_size = self.get_batch_size(input_sizes) 2960 ns = self.get_ns(input_sizes) 2961 2962 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 2963 return tensor_sizes.outputs
Returns the tensor output sizes for given input_sizes. The output sizes are exact only if input_sizes is a valid input shape; otherwise they might be larger than the actual (valid) output sizes.
2965 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 2966 """get parameter `n` for each parameterized axis 2967 such that the valid input size is >= the given input size""" 2968 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 2969 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 2970 for tid in input_sizes: 2971 for aid, s in input_sizes[tid].items(): 2972 size_descr = axes[tid][aid].size 2973 if isinstance(size_descr, ParameterizedSize): 2974 ret[(tid, aid)] = size_descr.get_n(s) 2975 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 2976 pass 2977 else: 2978 assert_never(size_descr) 2979 2980 return ret
Get parameter n for each parameterized axis such that the valid input size is >= the given input size.
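A standalone arithmetic sketch of this relation (size = min + n * step): find the smallest n whose size covers a requested input size, as get_ns does per parameterized axis (the real ParameterizedSize.get_n may differ in edge cases):

```python
# Sketch: smallest n such that min + n * step >= requested size.
from math import ceil

def smallest_n(requested: int, size_min: int, step: int) -> int:
    return max(0, ceil((requested - size_min) / step))

# e.g. an axis with min=64, step=16 needs n=3 to cover a requested size of 110:
n = smallest_n(110, 64, 16)
assert n == 3 and 64 + n * 16 >= 110
```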
2982 def get_tensor_sizes( 2983 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 2984 ) -> _TensorSizes: 2985 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 2986 return _TensorSizes( 2987 { 2988 t: { 2989 aa: axis_sizes.inputs[(tt, aa)] 2990 for tt, aa in axis_sizes.inputs 2991 if tt == t 2992 } 2993 for t in {tt for tt, _ in axis_sizes.inputs} 2994 }, 2995 { 2996 t: { 2997 aa: axis_sizes.outputs[(tt, aa)] 2998 for tt, aa in axis_sizes.outputs 2999 if tt == t 3000 } 3001 for t in {tt for tt, _ in axis_sizes.outputs} 3002 }, 3003 )
3005 def get_axis_sizes( 3006 self, 3007 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], 3008 batch_size: Optional[int] = None, 3009 *, 3010 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None, 3011 ) -> _AxisSizes: 3012 """Determine input and output block shape for scale factors **ns** 3013 of parameterized input sizes. 3014 3015 Args: 3016 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) 3017 that is parameterized as `size = min + n * step`. 3018 batch_size: The desired size of the batch dimension. 3019 If given **batch_size** overwrites any batch size present in 3020 **max_input_shape**. Default 1. 3021 max_input_shape: Limits the derived block shapes. 3022 Each axis for which the input size, parameterized by `n`, is larger 3023 than **max_input_shape** is set to the minimal value `n_min` for which 3024 this is still true. 3025 Use this for small input samples or large values of **ns**. 3026 Or simply whenever you know the full input shape. 3027 3028 Returns: 3029 Resolved axis sizes for model inputs and outputs. 3030 """ 3031 max_input_shape = max_input_shape or {} 3032 if batch_size is None: 3033 for (_t_id, a_id), s in max_input_shape.items(): 3034 if a_id == BATCH_AXIS_ID: 3035 batch_size = s 3036 break 3037 else: 3038 batch_size = 1 3039 3040 all_axes = { 3041 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs) 3042 } 3043 3044 inputs: Dict[Tuple[TensorId, AxisId], int] = {} 3045 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {} 3046 3047 def get_axis_size(a: Union[InputAxis, OutputAxis]): 3048 if isinstance(a, BatchAxis): 3049 if (t_descr.id, a.id) in ns: 3050 logger.warning( 3051 "Ignoring unexpected size increment factor (n) for batch axis" 3052 + " of tensor '{}'.", 3053 t_descr.id, 3054 ) 3055 return batch_size 3056 elif isinstance(a.size, int): 3057 if (t_descr.id, a.id) in ns: 3058 logger.warning( 3059 "Ignoring unexpected size increment factor (n) for fixed size" 3060 + " axis '{}' of tensor '{}'.", 3061 a.id, 3062 t_descr.id, 3063 ) 3064 return a.size 3065 elif isinstance(a.size, ParameterizedSize): 3066 if (t_descr.id, a.id) not in ns: 3067 raise ValueError( 3068 "Size increment factor (n) missing for parametrized axis" 3069 + f" '{a.id}' of tensor '{t_descr.id}'." 
3070 ) 3071 n = ns[(t_descr.id, a.id)] 3072 s_max = max_input_shape.get((t_descr.id, a.id)) 3073 if s_max is not None: 3074 n = min(n, a.size.get_n(s_max)) 3075 3076 return a.size.get_size(n) 3077 3078 elif isinstance(a.size, SizeReference): 3079 if (t_descr.id, a.id) in ns: 3080 logger.warning( 3081 "Ignoring unexpected size increment factor (n) for axis '{}'" 3082 + " of tensor '{}' with size reference.", 3083 a.id, 3084 t_descr.id, 3085 ) 3086 assert not isinstance(a, BatchAxis) 3087 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id] 3088 assert not isinstance(ref_axis, BatchAxis) 3089 ref_key = (a.size.tensor_id, a.size.axis_id) 3090 ref_size = inputs.get(ref_key, outputs.get(ref_key)) 3091 assert ref_size is not None, ref_key 3092 assert not isinstance(ref_size, _DataDepSize), ref_key 3093 return a.size.get_size( 3094 axis=a, 3095 ref_axis=ref_axis, 3096 ref_size=ref_size, 3097 ) 3098 elif isinstance(a.size, DataDependentSize): 3099 if (t_descr.id, a.id) in ns: 3100 logger.warning( 3101 "Ignoring unexpected increment factor (n) for data dependent" 3102 + " size axis '{}' of tensor '{}'.", 3103 a.id, 3104 t_descr.id, 3105 ) 3106 return _DataDepSize(a.size.min, a.size.max) 3107 else: 3108 assert_never(a.size) 3109 3110 # first resolve all , but the `SizeReference` input sizes 3111 for t_descr in self.inputs: 3112 for a in t_descr.axes: 3113 if not isinstance(a.size, SizeReference): 3114 s = get_axis_size(a) 3115 assert not isinstance(s, _DataDepSize) 3116 inputs[t_descr.id, a.id] = s 3117 3118 # resolve all other input axis sizes 3119 for t_descr in self.inputs: 3120 for a in t_descr.axes: 3121 if isinstance(a.size, SizeReference): 3122 s = get_axis_size(a) 3123 assert not isinstance(s, _DataDepSize) 3124 inputs[t_descr.id, a.id] = s 3125 3126 # resolve all output axis sizes 3127 for t_descr in self.outputs: 3128 for a in t_descr.axes: 3129 assert not isinstance(a.size, ParameterizedSize) 3130 s = get_axis_size(a) 3131 outputs[t_descr.id, a.id] = s 3132 3133 return _AxisSizes(inputs=inputs, outputs=outputs)
Determine input and output block shape for scale factors ns of parameterized input sizes.
Arguments:
- ns: Scale factor n for each axis (keyed by (tensor_id, axis_id)) that is parameterized as size = min + n * step.
- batch_size: The desired size of the batch dimension. If given, batch_size overwrites any batch size present in max_input_shape. Default 1.
- max_input_shape: Limits the derived block shapes. Each axis for which the input size, parameterized by n, is larger than max_input_shape is set to the minimal value n_min for which this is still true. Use this for small input samples or large values of ns, or simply whenever you know the full input shape.
Returns:
Resolved axis sizes for model inputs and outputs.
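A standalone sketch of how max_input_shape caps a parameterized axis, mirroring the clamping n = min(n, get_n(s_max)) used above:

```python
# Sketch: cap the scale factor n so that min + n * step stays close to s_max.
from math import ceil

def clamp_n(n: int, size_min: int, step: int, s_max: int) -> int:
    n_cap = max(0, ceil((s_max - size_min) / step))
    return min(n, n_cap)

# axis with min=64, step=16, requested n=10, but the sample is only 128 wide:
n = clamp_n(10, 64, 16, 128)
assert n == 4 and 64 + n * 16 == 128
```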
3141 @classmethod 3142 def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None: 3143 """Convert metadata following an older format version to this classes' format 3144 without validating the result. 3145 """ 3146 if ( 3147 data.get("type") == "model" 3148 and isinstance(fv := data.get("format_version"), str) 3149 and fv.count(".") == 2 3150 ): 3151 fv_parts = fv.split(".") 3152 if any(not p.isdigit() for p in fv_parts): 3153 return 3154 3155 fv_tuple = tuple(map(int, fv_parts)) 3156 3157 assert cls.implemented_format_version_tuple[0:2] == (0, 5) 3158 if fv_tuple[:2] in ((0, 3), (0, 4)): 3159 m04 = _ModelDescr_v0_4.load(data) 3160 if isinstance(m04, InvalidDescr): 3161 try: 3162 updated = _model_conv.convert_as_dict( 3163 m04 # pyright: ignore[reportArgumentType] 3164 ) 3165 except Exception as e: 3166 logger.error( 3167 "Failed to convert from invalid model 0.4 description." 3168 + f"\nerror: {e}" 3169 + "\nProceeding with model 0.5 validation without conversion." 3170 ) 3171 updated = None 3172 else: 3173 updated = _model_conv.convert_as_dict(m04) 3174 3175 if updated is not None: 3176 data.clear() 3177 data.update(updated) 3178 3179 elif fv_tuple[:2] == (0, 5): 3180 # bump patch version 3181 data["format_version"] = cls.implemented_format_version
Convert metadata following an older format version to this class's format without validating the result.
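A usage sketch: the classmethod mutates the given dict in place; for a 0.5.x description only the patch version is bumped, while 0.3/0.4 descriptions are converted via the 0.4 converter:

```python
# Sketch: in-place conversion of an older model description dict.
from bioimageio.spec.model.v0_5 import ModelDescr

data = {"type": "model", "format_version": "0.5.0"}
ModelDescr.convert_from_old_format_wo_validation(data)
print(data["format_version"])  # the currently implemented patch version, e.g. "0.5.4"
```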
124 def wrapped_model_post_init(self: BaseModel, context: Any, /) -> None: 125 """We need to both initialize private attributes and call the user-defined model_post_init 126 method. 127 """ 128 init_private_attributes(self, context) 129 original_model_post_init(self, context)
We need to both initialize private attributes and call the user-defined model_post_init method.
3406def generate_covers( 3407 inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]], 3408 outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]], 3409) -> List[Path]: 3410 def squeeze( 3411 data: NDArray[Any], axes: Sequence[AnyAxis] 3412 ) -> Tuple[NDArray[Any], List[AnyAxis]]: 3413 """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining""" 3414 if data.ndim != len(axes): 3415 raise ValueError( 3416 f"tensor shape {data.shape} does not match described axes" 3417 + f" {[a.id for a in axes]}" 3418 ) 3419 3420 axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1] 3421 return data.squeeze(), axes 3422 3423 def normalize( 3424 data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7 3425 ) -> NDArray[np.float32]: 3426 data = data.astype("float32") 3427 data -= data.min(axis=axis, keepdims=True) 3428 data /= data.max(axis=axis, keepdims=True) + eps 3429 return data 3430 3431 def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]): 3432 original_shape = data.shape 3433 data, axes = squeeze(data, axes) 3434 3435 # take slice fom any batch or index axis if needed 3436 # and convert the first channel axis and take a slice from any additional channel axes 3437 slices: Tuple[slice, ...] = () 3438 ndim = data.ndim 3439 ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2 3440 has_c_axis = False 3441 for i, a in enumerate(axes): 3442 s = data.shape[i] 3443 assert s > 1 3444 if ( 3445 isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis)) 3446 and ndim > ndim_need 3447 ): 3448 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3449 ndim -= 1 3450 elif isinstance(a, ChannelAxis): 3451 if has_c_axis: 3452 # second channel axis 3453 data = data[slices + (slice(0, 1),)] 3454 ndim -= 1 3455 else: 3456 has_c_axis = True 3457 if s == 2: 3458 # visualize two channels with cyan and magenta 3459 data = np.concatenate( 3460 [ 3461 data[slices + (slice(1, 2),)], 3462 data[slices + (slice(0, 1),)], 3463 ( 3464 data[slices + (slice(0, 1),)] 3465 + data[slices + (slice(1, 2),)] 3466 ) 3467 / 2, # TODO: take maximum instead? 
3468 ], 3469 axis=i, 3470 ) 3471 elif data.shape[i] == 3: 3472 pass # visualize 3 channels as RGB 3473 else: 3474 # visualize first 3 channels as RGB 3475 data = data[slices + (slice(3),)] 3476 3477 assert data.shape[i] == 3 3478 3479 slices += (slice(None),) 3480 3481 data, axes = squeeze(data, axes) 3482 assert len(axes) == ndim 3483 # take slice from z axis if needed 3484 slices = () 3485 if ndim > ndim_need: 3486 for i, a in enumerate(axes): 3487 s = data.shape[i] 3488 if a.id == AxisId("z"): 3489 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3490 data, axes = squeeze(data, axes) 3491 ndim -= 1 3492 break 3493 3494 slices += (slice(None),) 3495 3496 # take slice from any space or time axis 3497 slices = () 3498 3499 for i, a in enumerate(axes): 3500 if ndim <= ndim_need: 3501 break 3502 3503 s = data.shape[i] 3504 assert s > 1 3505 if isinstance( 3506 a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis) 3507 ): 3508 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3509 ndim -= 1 3510 3511 slices += (slice(None),) 3512 3513 del slices 3514 data, axes = squeeze(data, axes) 3515 assert len(axes) == ndim 3516 3517 if (has_c_axis and ndim != 3) or ndim != 2: 3518 raise ValueError( 3519 f"Failed to construct cover image from shape {original_shape}" 3520 ) 3521 3522 if not has_c_axis: 3523 assert ndim == 2 3524 data = np.repeat(data[:, :, None], 3, axis=2) 3525 axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB")))) 3526 ndim += 1 3527 3528 assert ndim == 3 3529 3530 # transpose axis order such that longest axis comes first... 3531 axis_order = list(np.argsort(list(data.shape))) 3532 axis_order.reverse() 3533 # ... and channel axis is last 3534 c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0] 3535 axis_order.append(axis_order.pop(c)) 3536 axes = [axes[ao] for ao in axis_order] 3537 data = data.transpose(axis_order) 3538 3539 # h, w = data.shape[:2] 3540 # if h / w in (1.0 or 2.0): 3541 # pass 3542 # elif h / w < 2: 3543 # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images 3544 3545 norm_along = ( 3546 tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None 3547 ) 3548 # normalize the data and map to 8 bit 3549 data = normalize(data, norm_along) 3550 data = (data * 255).astype("uint8") 3551 3552 return data 3553 3554 def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]): 3555 assert im0.dtype == im1.dtype == np.uint8 3556 assert im0.shape == im1.shape 3557 assert im0.ndim == 3 3558 N, M, C = im0.shape 3559 assert C == 3 3560 out = np.ones((N, M, C), dtype="uint8") 3561 for c in range(C): 3562 outc = np.tril(im0[..., c]) 3563 mask = outc == 0 3564 outc[mask] = np.triu(im1[..., c])[mask] 3565 out[..., c] = outc 3566 3567 return out 3568 3569 ipt_descr, ipt = inputs[0] 3570 out_descr, out = outputs[0] 3571 3572 ipt_img = to_2d_image(ipt, ipt_descr.axes) 3573 out_img = to_2d_image(out, out_descr.axes) 3574 3575 cover_folder = Path(mkdtemp()) 3576 if ipt_img.shape == out_img.shape: 3577 covers = [cover_folder / "cover.png"] 3578 imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img)) 3579 else: 3580 covers = [cover_folder / "input.png", cover_folder / "output.png"] 3581 imwrite(covers[0], ipt_img) 3582 imwrite(covers[1], out_img) 3583 3584 return covers
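A standalone numpy sketch mirroring the diagonal cover composition in create_diagonal_split_image: the lower triangle is taken from the input image, the upper triangle from the output image, per channel:

```python
# Sketch of the diagonal split used for generated cover images.
import numpy as np

def diagonal_split(im0, im1):
    assert im0.shape == im1.shape and im0.ndim == 3 and im0.shape[2] == 3
    out = np.empty_like(im0)
    for c in range(3):
        lower = np.tril(im0[..., c])          # keep lower triangle of the input
        mask = lower == 0                     # fill the rest ...
        lower[mask] = np.triu(im1[..., c])[mask]  # ... from the output's upper triangle
        out[..., c] = lower
    return out

a = np.full((4, 4, 3), 50, dtype="uint8")
b = np.full((4, 4, 3), 200, dtype="uint8")
print(diagonal_split(a, b)[..., 0])
```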