bioimageio.spec.model.v0_5
1from __future__ import annotations 2 3import collections.abc 4import re 5import string 6import warnings 7from abc import ABC 8from copy import deepcopy 9from datetime import datetime 10from itertools import chain 11from math import ceil 12from pathlib import Path, PurePosixPath 13from tempfile import mkdtemp 14from typing import ( 15 TYPE_CHECKING, 16 Any, 17 ClassVar, 18 Dict, 19 FrozenSet, 20 Generic, 21 List, 22 Literal, 23 Mapping, 24 NamedTuple, 25 Optional, 26 Sequence, 27 Set, 28 Tuple, 29 Type, 30 TypeVar, 31 Union, 32 cast, 33) 34 35import numpy as np 36from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate 37from imageio.v3 import imread, imwrite # pyright: ignore[reportUnknownVariableType] 38from loguru import logger 39from numpy.typing import NDArray 40from pydantic import ( 41 Discriminator, 42 Field, 43 RootModel, 44 Tag, 45 ValidationInfo, 46 WrapSerializer, 47 field_validator, 48 model_validator, 49) 50from typing_extensions import Annotated, LiteralString, Self, assert_never, get_args 51 52from .._internal.common_nodes import ( 53 InvalidDescr, 54 Node, 55 NodeWithExplicitlySetFields, 56) 57from .._internal.constants import DTYPE_LIMITS 58from .._internal.field_warning import issue_warning, warn 59from .._internal.io import BioimageioYamlContent as BioimageioYamlContent 60from .._internal.io import FileDescr as FileDescr 61from .._internal.io import WithSuffix, YamlValue, download 62from .._internal.io_basics import AbsoluteFilePath as AbsoluteFilePath 63from .._internal.io_basics import Sha256 as Sha256 64from .._internal.io_utils import load_array 65from .._internal.node_converter import Converter 66from .._internal.types import Datetime as Datetime 67from .._internal.types import Identifier as Identifier 68from .._internal.types import ( 69 ImportantFileSource, 70 LowerCaseIdentifier, 71 LowerCaseIdentifierAnno, 72 SiUnit, 73) 74from .._internal.types import NotEmpty as NotEmpty 75from .._internal.url import HttpUrl as HttpUrl 76from .._internal.validation_context import validation_context_var 77from .._internal.validator_annotations import RestrictCharacters 78from .._internal.version_type import Version as Version 79from .._internal.warning_levels import INFO 80from ..dataset.v0_2 import DatasetDescr as DatasetDescr02 81from ..dataset.v0_2 import LinkedDataset as LinkedDataset02 82from ..dataset.v0_3 import DatasetDescr as DatasetDescr 83from ..dataset.v0_3 import DatasetId as DatasetId 84from ..dataset.v0_3 import LinkedDataset as LinkedDataset 85from ..dataset.v0_3 import Uploader as Uploader 86from ..generic.v0_3 import ( 87 VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS, 88) 89from ..generic.v0_3 import Author as Author 90from ..generic.v0_3 import BadgeDescr as BadgeDescr 91from ..generic.v0_3 import CiteEntry as CiteEntry 92from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId 93from ..generic.v0_3 import ( 94 DocumentationSource, 95 GenericModelDescrBase, 96 LinkedResourceBase, 97 _author_conv, # pyright: ignore[reportPrivateUsage] 98 _maintainer_conv, # pyright: ignore[reportPrivateUsage] 99) 100from ..generic.v0_3 import Doi as Doi 101from ..generic.v0_3 import LicenseId as LicenseId 102from ..generic.v0_3 import LinkedResource as LinkedResource 103from ..generic.v0_3 import Maintainer as Maintainer 104from ..generic.v0_3 import OrcidId as OrcidId 105from ..generic.v0_3 import RelativeFilePath as RelativeFilePath 106from ..generic.v0_3 import ResourceId as ResourceId 107from .v0_4 import Author as _Author_v0_4 
108from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4 109from .v0_4 import CallableFromDepencency as CallableFromDepencency 110from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4 111from .v0_4 import CallableFromFile as _CallableFromFile_v0_4 112from .v0_4 import ClipDescr as _ClipDescr_v0_4 113from .v0_4 import ClipKwargs as ClipKwargs 114from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4 115from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4 116from .v0_4 import KnownRunMode as KnownRunMode 117from .v0_4 import ModelDescr as _ModelDescr_v0_4 118from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4 119from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4 120from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4 121from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4 122from .v0_4 import ProcessingKwargs as ProcessingKwargs 123from .v0_4 import RunMode as RunMode 124from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4 125from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4 126from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4 127from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4 128from .v0_4 import TensorName as _TensorName_v0_4 129from .v0_4 import WeightsFormat as WeightsFormat 130from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4 131from .v0_4 import package_weights 132 133SpaceUnit = Literal[ 134 "attometer", 135 "angstrom", 136 "centimeter", 137 "decimeter", 138 "exameter", 139 "femtometer", 140 "foot", 141 "gigameter", 142 "hectometer", 143 "inch", 144 "kilometer", 145 "megameter", 146 "meter", 147 "micrometer", 148 "mile", 149 "millimeter", 150 "nanometer", 151 "parsec", 152 "petameter", 153 "picometer", 154 "terameter", 155 "yard", 156 "yoctometer", 157 "yottameter", 158 "zeptometer", 159 "zettameter", 160] 161"""Space unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)""" 162 163TimeUnit = Literal[ 164 "attosecond", 165 "centisecond", 166 "day", 167 "decisecond", 168 "exasecond", 169 "femtosecond", 170 "gigasecond", 171 "hectosecond", 172 "hour", 173 "kilosecond", 174 "megasecond", 175 "microsecond", 176 "millisecond", 177 "minute", 178 "nanosecond", 179 "petasecond", 180 "picosecond", 181 "second", 182 "terasecond", 183 "yoctosecond", 184 "yottasecond", 185 "zeptosecond", 186 "zettasecond", 187] 188"""Time unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)""" 189 190AxisType = Literal["batch", "channel", "index", "time", "space"] 191 192 193class TensorId(LowerCaseIdentifier): 194 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 195 Annotated[LowerCaseIdentifierAnno, MaxLen(32)] 196 ] 197 198 199class AxisId(LowerCaseIdentifier): 200 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 201 Annotated[LowerCaseIdentifierAnno, MaxLen(16)] 202 ] 203 204 205def _is_batch(a: str) -> bool: 206 return a == BATCH_AXIS_ID 207 208 209def _is_not_batch(a: str) -> bool: 210 return not _is_batch(a) 211 212 213NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)] 214 215PostprocessingId = Literal[ 216 "binarize", 217 "clip", 218 "ensure_dtype", 219 "fixed_zero_mean_unit_variance", 220 "scale_linear", 221 "scale_mean_variance", 222 "scale_range", 223 "sigmoid", 224 "zero_mean_unit_variance", 225] 226PreprocessingId = Literal[ 227 "binarize", 228 "clip", 229 "ensure_dtype", 230 "scale_linear", 231 "sigmoid", 
232 "zero_mean_unit_variance", 233 "scale_range", 234] 235 236 237SAME_AS_TYPE = "<same as type>" 238 239 240ParameterizedSize_N = int 241 242 243class ParameterizedSize(Node): 244 """Describes a range of valid tensor axis sizes as `size = min + n*step`.""" 245 246 N: ClassVar[Type[int]] = ParameterizedSize_N 247 """integer to parameterize this axis""" 248 249 min: Annotated[int, Gt(0)] 250 step: Annotated[int, Gt(0)] 251 252 def validate_size(self, size: int) -> int: 253 if size < self.min: 254 raise ValueError(f"size {size} < {self.min}") 255 if (size - self.min) % self.step != 0: 256 raise ValueError( 257 f"axis of size {size} is not parameterized by `min + n*step` =" 258 + f" `{self.min} + n*{self.step}`" 259 ) 260 261 return size 262 263 def get_size(self, n: ParameterizedSize_N) -> int: 264 return self.min + self.step * n 265 266 def get_n(self, s: int) -> ParameterizedSize_N: 267 """return smallest n parameterizing a size greater or equal than `s`""" 268 return ceil((s - self.min) / self.step) 269 270 271class DataDependentSize(Node): 272 min: Annotated[int, Gt(0)] = 1 273 max: Annotated[Optional[int], Gt(1)] = None 274 275 @model_validator(mode="after") 276 def _validate_max_gt_min(self): 277 if self.max is not None and self.min >= self.max: 278 raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}") 279 280 return self 281 282 def validate_size(self, size: int) -> int: 283 if size < self.min: 284 raise ValueError(f"size {size} < {self.min}") 285 286 if self.max is not None and size > self.max: 287 raise ValueError(f"size {size} > {self.max}") 288 289 return size 290 291 292class SizeReference(Node): 293 """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis. 294 295 `axis.size = reference.size * reference.scale / axis.scale + offset` 296 297 Note: 298 1. The axis and the referenced axis need to have the same unit (or no unit). 299 2. Batch axes may not be referenced. 300 3. Fractions are rounded down. 301 4. If the reference axis is `concatenable` the referencing axis is assumed to be 302 `concatenable` as well with the same block order. 303 304 Example: 305 An unisotropic input image of w*h=100*49 pixels depicts a phsical space of 200*196mm². 306 Let's assume that we want to express the image height h in relation to its width w 307 instead of only accepting input images of exactly 100*49 pixels 308 (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`). 309 310 >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2) 311 >>> h = SpaceInputAxis( 312 ... id=AxisId("h"), 313 ... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1), 314 ... unit="millimeter", 315 ... scale=4, 316 ... 
) 317 >>> print(h.size.get_size(h, w)) 318 49 319 320 ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49 321 """ 322 323 tensor_id: TensorId 324 """tensor id of the reference axis""" 325 326 axis_id: AxisId 327 """axis id of the reference axis""" 328 329 offset: int = 0 330 331 def get_size( 332 self, 333 axis: Union[ 334 ChannelAxis, 335 IndexInputAxis, 336 IndexOutputAxis, 337 TimeInputAxis, 338 SpaceInputAxis, 339 TimeOutputAxis, 340 TimeOutputAxisWithHalo, 341 SpaceOutputAxis, 342 SpaceOutputAxisWithHalo, 343 ], 344 ref_axis: Union[ 345 ChannelAxis, 346 IndexInputAxis, 347 IndexOutputAxis, 348 TimeInputAxis, 349 SpaceInputAxis, 350 TimeOutputAxis, 351 TimeOutputAxisWithHalo, 352 SpaceOutputAxis, 353 SpaceOutputAxisWithHalo, 354 ], 355 n: ParameterizedSize_N = 0, 356 ref_size: Optional[int] = None, 357 ): 358 """Compute the concrete size for a given axis and its reference axis. 359 360 Args: 361 axis: The axis this `SizeReference` is the size of. 362 ref_axis: The reference axis to compute the size from. 363 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`) 364 and no fixed **ref_size** is given, 365 **n** is used to compute the size of the parameterized **ref_axis**. 366 ref_size: Overwrite the reference size instead of deriving it from 367 **ref_axis** 368 (**ref_axis.scale** is still used; any given **n** is ignored). 369 """ 370 assert ( 371 axis.size == self 372 ), "Given `axis.size` is not defined by this `SizeReference`" 373 374 assert ( 375 ref_axis.id == self.axis_id 376 ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}." 377 378 assert axis.unit == ref_axis.unit, ( 379 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`," 380 f" but {axis.unit}!={ref_axis.unit}" 381 ) 382 if ref_size is None: 383 if isinstance(ref_axis.size, (int, float)): 384 ref_size = ref_axis.size 385 elif isinstance(ref_axis.size, ParameterizedSize): 386 ref_size = ref_axis.size.get_size(n) 387 elif isinstance(ref_axis.size, DataDependentSize): 388 raise ValueError( 389 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`." 390 ) 391 elif isinstance(ref_axis.size, SizeReference): 392 raise ValueError( 393 "Reference axis referenced in `SizeReference` may not be sized by a" 394 + " `SizeReference` itself." 395 ) 396 else: 397 assert_never(ref_axis.size) 398 399 return int(ref_size * ref_axis.scale / axis.scale + self.offset) 400 401 @staticmethod 402 def _get_unit( 403 axis: Union[ 404 ChannelAxis, 405 IndexInputAxis, 406 IndexOutputAxis, 407 TimeInputAxis, 408 SpaceInputAxis, 409 TimeOutputAxis, 410 TimeOutputAxisWithHalo, 411 SpaceOutputAxis, 412 SpaceOutputAxisWithHalo, 413 ], 414 ): 415 return axis.unit 416 417 418class AxisBase(NodeWithExplicitlySetFields): 419 fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({"type"}) 420 421 id: AxisId 422 """An axis id unique across all axes of one tensor.""" 423 424 description: Annotated[str, MaxLen(128)] = "" 425 426 427class WithHalo(Node): 428 halo: Annotated[int, Ge(1)] 429 """The halo should be cropped from the output tensor to avoid boundary effects. 430 It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`. 
431 To document a halo that is already cropped by the model use `size.offset` instead.""" 432 433 size: Annotated[ 434 SizeReference, 435 Field( 436 examples=[ 437 10, 438 SizeReference( 439 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 440 ).model_dump(mode="json"), 441 ] 442 ), 443 ] 444 """reference to another axis with an optional offset (see `SizeReference`)""" 445 446 447BATCH_AXIS_ID = AxisId("batch") 448 449 450class BatchAxis(AxisBase): 451 type: Literal["batch"] = "batch" 452 id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID 453 size: Optional[Literal[1]] = None 454 """The batch size may be fixed to 1, 455 otherwise (the default) it may be chosen arbitrarily depending on available memory""" 456 457 @property 458 def scale(self): 459 return 1.0 460 461 @property 462 def concatenable(self): 463 return True 464 465 @property 466 def unit(self): 467 return None 468 469 470class ChannelAxis(AxisBase): 471 type: Literal["channel"] = "channel" 472 id: NonBatchAxisId = AxisId("channel") 473 channel_names: NotEmpty[List[Identifier]] 474 475 @property 476 def size(self) -> int: 477 return len(self.channel_names) 478 479 @property 480 def concatenable(self): 481 return False 482 483 @property 484 def scale(self) -> float: 485 return 1.0 486 487 @property 488 def unit(self): 489 return None 490 491 492class IndexAxisBase(AxisBase): 493 type: Literal["index"] = "index" 494 id: NonBatchAxisId = AxisId("index") 495 496 @property 497 def scale(self) -> float: 498 return 1.0 499 500 @property 501 def unit(self): 502 return None 503 504 505class _WithInputAxisSize(Node): 506 size: Annotated[ 507 Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference], 508 Field( 509 examples=[ 510 10, 511 ParameterizedSize(min=32, step=16).model_dump(mode="json"), 512 SizeReference( 513 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 514 ).model_dump(mode="json"), 515 ] 516 ), 517 ] 518 """The size/length of this axis can be specified as 519 - fixed integer 520 - parameterized series of valid sizes (`ParameterizedSize`) 521 - reference to another axis with an optional offset (`SizeReference`) 522 """ 523 524 525class IndexInputAxis(IndexAxisBase, _WithInputAxisSize): 526 concatenable: bool = False 527 """If a model has a `concatenable` input axis, it can be processed blockwise, 528 splitting a longer sample axis into blocks matching its input tensor description. 529 Output axes are concatenable if they have a `SizeReference` to a concatenable 530 input axis. 531 """ 532 533 534class IndexOutputAxis(IndexAxisBase): 535 size: Annotated[ 536 Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize], 537 Field( 538 examples=[ 539 10, 540 SizeReference( 541 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 542 ).model_dump(mode="json"), 543 ] 544 ), 545 ] 546 """The size/length of this axis can be specified as 547 - fixed integer 548 - reference to another axis with an optional offset (`SizeReference`) 549 - data dependent size using `DataDependentSize` (size is only known after model inference) 550 """ 551 552 553class TimeAxisBase(AxisBase): 554 type: Literal["time"] = "time" 555 id: NonBatchAxisId = AxisId("time") 556 unit: Optional[TimeUnit] = None 557 scale: Annotated[float, Gt(0)] = 1.0 558 559 560class TimeInputAxis(TimeAxisBase, _WithInputAxisSize): 561 concatenable: bool = False 562 """If a model has a `concatenable` input axis, it can be processed blockwise, 563 splitting a longer sample axis into blocks matching its input tensor description. 
564 Output axes are concatenable if they have a `SizeReference` to a concatenable 565 input axis. 566 """ 567 568 569class SpaceAxisBase(AxisBase): 570 type: Literal["space"] = "space" 571 id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x") 572 unit: Optional[SpaceUnit] = None 573 scale: Annotated[float, Gt(0)] = 1.0 574 575 576class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize): 577 concatenable: bool = False 578 """If a model has a `concatenable` input axis, it can be processed blockwise, 579 splitting a longer sample axis into blocks matching its input tensor description. 580 Output axes are concatenable if they have a `SizeReference` to a concatenable 581 input axis. 582 """ 583 584 585_InputAxisUnion = Union[ 586 BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis 587] 588InputAxis = Annotated[_InputAxisUnion, Discriminator("type")] 589 590 591class _WithOutputAxisSize(Node): 592 size: Annotated[ 593 Union[Annotated[int, Gt(0)], SizeReference], 594 Field( 595 examples=[ 596 10, 597 SizeReference( 598 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 599 ).model_dump(mode="json"), 600 ] 601 ), 602 ] 603 """The size/length of this axis can be specified as 604 - fixed integer 605 - reference to another axis with an optional offset (see `SizeReference`) 606 """ 607 608 609class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize): 610 pass 611 612 613class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo): 614 pass 615 616 617def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]: 618 if isinstance(v, dict): 619 return "with_halo" if "halo" in v else "wo_halo" 620 else: 621 return "with_halo" if hasattr(v, "halo") else "wo_halo" 622 623 624_TimeOutputAxisUnion = Annotated[ 625 Union[ 626 Annotated[TimeOutputAxis, Tag("wo_halo")], 627 Annotated[TimeOutputAxisWithHalo, Tag("with_halo")], 628 ], 629 Discriminator(_get_halo_axis_discriminator_value), 630] 631 632 633class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize): 634 pass 635 636 637class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo): 638 pass 639 640 641_SpaceOutputAxisUnion = Annotated[ 642 Union[ 643 Annotated[SpaceOutputAxis, Tag("wo_halo")], 644 Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")], 645 ], 646 Discriminator(_get_halo_axis_discriminator_value), 647] 648 649 650_OutputAxisUnion = Union[ 651 BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion 652] 653OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")] 654 655AnyAxis = Union[InputAxis, OutputAxis] 656 657TVs = Union[ 658 NotEmpty[List[int]], 659 NotEmpty[List[float]], 660 NotEmpty[List[bool]], 661 NotEmpty[List[str]], 662] 663 664 665NominalOrOrdinalDType = Literal[ 666 "float32", 667 "float64", 668 "uint8", 669 "int8", 670 "uint16", 671 "int16", 672 "uint32", 673 "int32", 674 "uint64", 675 "int64", 676 "bool", 677] 678 679 680class NominalOrOrdinalDataDescr(Node): 681 values: TVs 682 """A fixed set of nominal or an ascending sequence of ordinal values. 683 In this case `data_type` is required to be an unsigend integer type, e.g. 'uint8'. 684 String `values` are interpreted as labels for tensor values 0, ..., N. 685 Note: as YAML 1.2 does not natively support a "set" datatype, 686 nominal values should be given as a sequence (aka list/array) as well. 
687 """ 688 689 type: Annotated[ 690 NominalOrOrdinalDType, 691 Field( 692 examples=[ 693 "float32", 694 "uint8", 695 "uint16", 696 "int64", 697 "bool", 698 ], 699 ), 700 ] = "uint8" 701 702 @model_validator(mode="after") 703 def _validate_values_match_type( 704 self, 705 ) -> Self: 706 incompatible: List[Any] = [] 707 for v in self.values: 708 if self.type == "bool": 709 if not isinstance(v, bool): 710 incompatible.append(v) 711 elif self.type in DTYPE_LIMITS: 712 if ( 713 isinstance(v, (int, float)) 714 and ( 715 v < DTYPE_LIMITS[self.type].min 716 or v > DTYPE_LIMITS[self.type].max 717 ) 718 or (isinstance(v, str) and "uint" not in self.type) 719 or (isinstance(v, float) and "int" in self.type) 720 ): 721 incompatible.append(v) 722 else: 723 incompatible.append(v) 724 725 if len(incompatible) == 5: 726 incompatible.append("...") 727 break 728 729 if incompatible: 730 raise ValueError( 731 f"data type '{self.type}' incompatible with values {incompatible}" 732 ) 733 734 return self 735 736 unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None 737 738 @property 739 def range(self): 740 if isinstance(self.values[0], str): 741 return 0, len(self.values) - 1 742 else: 743 return min(self.values), max(self.values) 744 745 746IntervalOrRatioDType = Literal[ 747 "float32", 748 "float64", 749 "uint8", 750 "int8", 751 "uint16", 752 "int16", 753 "uint32", 754 "int32", 755 "uint64", 756 "int64", 757] 758 759 760class IntervalOrRatioDataDescr(Node): 761 type: Annotated[ # todo: rename to dtype 762 IntervalOrRatioDType, 763 Field( 764 examples=["float32", "float64", "uint8", "uint16"], 765 ), 766 ] = "float32" 767 range: Tuple[Optional[float], Optional[float]] = ( 768 None, 769 None, 770 ) 771 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor. 772 `None` corresponds to min/max of what can be expressed by `data_type`.""" 773 unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit" 774 scale: float = 1.0 775 """Scale for data on an interval (or ratio) scale.""" 776 offset: Optional[float] = None 777 """Offset for data on a ratio scale.""" 778 779 780TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr] 781 782 783class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC): 784 """processing base class""" 785 786 # id: Literal[PreprocessingId, PostprocessingId] # make abstract field 787 fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({"id"}) 788 789 790class BinarizeKwargs(ProcessingKwargs): 791 """key word arguments for `BinarizeDescr`""" 792 793 threshold: float 794 """The fixed threshold""" 795 796 797class BinarizeAlongAxisKwargs(ProcessingKwargs): 798 """key word arguments for `BinarizeDescr`""" 799 800 threshold: NotEmpty[List[float]] 801 """The fixed threshold values along `axis`""" 802 803 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 804 """The `threshold` axis""" 805 806 807class BinarizeDescr(ProcessingDescrBase): 808 """Binarize the tensor with a fixed threshold. 809 810 Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold` 811 will be set to one, values below the threshold to zero. 812 813 Examples: 814 - in YAML 815 ```yaml 816 postprocessing: 817 - id: binarize 818 kwargs: 819 axis: 'channel' 820 threshold: [0.25, 0.5, 0.75] 821 ``` 822 - in Python: 823 >>> postprocessing = [BinarizeDescr( 824 ... kwargs=BinarizeAlongAxisKwargs( 825 ... axis=AxisId('channel'), 826 ... threshold=[0.25, 0.5, 0.75], 827 ... ) 828 ... 
)] 829 """ 830 831 id: Literal["binarize"] = "binarize" 832 kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs] 833 834 835class ClipDescr(ProcessingDescrBase): 836 """Set tensor values below min to min and above max to max. 837 838 See `ScaleRangeDescr` for examples. 839 """ 840 841 id: Literal["clip"] = "clip" 842 kwargs: ClipKwargs 843 844 845class EnsureDtypeKwargs(ProcessingKwargs): 846 """key word arguments for `EnsureDtypeDescr`""" 847 848 dtype: Literal[ 849 "float32", 850 "float64", 851 "uint8", 852 "int8", 853 "uint16", 854 "int16", 855 "uint32", 856 "int32", 857 "uint64", 858 "int64", 859 "bool", 860 ] 861 862 863class EnsureDtypeDescr(ProcessingDescrBase): 864 """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching). 865 866 This can for example be used to ensure the inner neural network model gets a 867 different input tensor data type than the fully described bioimage.io model does. 868 869 Examples: 870 The described bioimage.io model (incl. preprocessing) accepts any 871 float32-compatible tensor, normalizes it with percentiles and clipping and then 872 casts it to uint8, which is what the neural network in this example expects. 873 - in YAML 874 ```yaml 875 inputs: 876 - data: 877 type: float32 # described bioimage.io model is compatible with any float32 input tensor 878 preprocessing: 879 - id: scale_range 880 kwargs: 881 axes: ['y', 'x'] 882 max_percentile: 99.8 883 min_percentile: 5.0 884 - id: clip 885 kwargs: 886 min: 0.0 887 max: 1.0 888 - id: ensure_dtype 889 kwargs: 890 dtype: uint8 891 ``` 892 - in Python: 893 >>> preprocessing = [ 894 ... ScaleRangeDescr( 895 ... kwargs=ScaleRangeKwargs( 896 ... axes= (AxisId('y'), AxisId('x')), 897 ... max_percentile= 99.8, 898 ... min_percentile= 5.0, 899 ... ) 900 ... ), 901 ... ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)), 902 ... EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")), 903 ... ] 904 """ 905 906 id: Literal["ensure_dtype"] = "ensure_dtype" 907 kwargs: EnsureDtypeKwargs 908 909 910class ScaleLinearKwargs(ProcessingKwargs): 911 """Key word arguments for `ScaleLinearDescr`""" 912 913 gain: float = 1.0 914 """multiplicative factor""" 915 916 offset: float = 0.0 917 """additive term""" 918 919 @model_validator(mode="after") 920 def _validate(self) -> Self: 921 if self.gain == 1.0 and self.offset == 0.0: 922 raise ValueError( 923 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 924 + " != 0.0." 925 ) 926 927 return self 928 929 930class ScaleLinearAlongAxisKwargs(ProcessingKwargs): 931 """Key word arguments for `ScaleLinearDescr`""" 932 933 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 934 """The axis of of gains/offsets values.""" 935 936 gain: Union[float, NotEmpty[List[float]]] = 1.0 937 """multiplicative factor""" 938 939 offset: Union[float, NotEmpty[List[float]]] = 0.0 940 """additive term""" 941 942 @model_validator(mode="after") 943 def _validate(self) -> Self: 944 945 if isinstance(self.gain, list): 946 if isinstance(self.offset, list): 947 if len(self.gain) != len(self.offset): 948 raise ValueError( 949 f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match." 950 ) 951 else: 952 self.offset = [float(self.offset)] * len(self.gain) 953 elif isinstance(self.offset, list): 954 self.gain = [float(self.gain)] * len(self.offset) 955 else: 956 raise ValueError( 957 "Do not specify an `axis` for scalar gain and offset values." 
958 ) 959 960 if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset): 961 raise ValueError( 962 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 963 + " != 0.0." 964 ) 965 966 return self 967 968 969class ScaleLinearDescr(ProcessingDescrBase): 970 """Fixed linear scaling. 971 972 Examples: 973 1. Scale with scalar gain and offset 974 - in YAML 975 ```yaml 976 preprocessing: 977 - id: scale_linear 978 kwargs: 979 gain: 2.0 980 offset: 3.0 981 ``` 982 - in Python: 983 >>> preprocessing = [ 984 ... ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0)) 985 ... ] 986 987 2. Independent scaling along an axis 988 - in YAML 989 ```yaml 990 preprocessing: 991 - id: scale_linear 992 kwargs: 993 axis: 'channel' 994 gain: [1.0, 2.0, 3.0] 995 ``` 996 - in Python: 997 >>> preprocessing = [ 998 ... ScaleLinearDescr( 999 ... kwargs=ScaleLinearAlongAxisKwargs( 1000 ... axis=AxisId("channel"), 1001 ... gain=[1.0, 2.0, 3.0], 1002 ... ) 1003 ... ) 1004 ... ] 1005 1006 """ 1007 1008 id: Literal["scale_linear"] = "scale_linear" 1009 kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs] 1010 1011 1012class SigmoidDescr(ProcessingDescrBase): 1013 """The logistic sigmoid funciton, a.k.a. expit function. 1014 1015 Examples: 1016 - in YAML 1017 ```yaml 1018 postprocessing: 1019 - id: sigmoid 1020 ``` 1021 - in Python: 1022 >>> postprocessing = [SigmoidDescr()] 1023 """ 1024 1025 id: Literal["sigmoid"] = "sigmoid" 1026 1027 @property 1028 def kwargs(self) -> ProcessingKwargs: 1029 """empty kwargs""" 1030 return ProcessingKwargs() 1031 1032 1033class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs): 1034 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 1035 1036 mean: float 1037 """The mean value to normalize with.""" 1038 1039 std: Annotated[float, Ge(1e-6)] 1040 """The standard deviation value to normalize with.""" 1041 1042 1043class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs): 1044 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 1045 1046 mean: NotEmpty[List[float]] 1047 """The mean value(s) to normalize with.""" 1048 1049 std: NotEmpty[List[Annotated[float, Ge(1e-6)]]] 1050 """The standard deviation value(s) to normalize with. 1051 Size must match `mean` values.""" 1052 1053 axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])] 1054 """The axis of the mean/std values to normalize each entry along that dimension 1055 separately.""" 1056 1057 @model_validator(mode="after") 1058 def _mean_and_std_match(self) -> Self: 1059 if len(self.mean) != len(self.std): 1060 raise ValueError( 1061 f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})" 1062 + " must match." 1063 ) 1064 1065 return self 1066 1067 1068class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase): 1069 """Subtract a given mean and divide by the standard deviation. 1070 1071 Normalize with fixed, precomputed values for 1072 `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std` 1073 Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given 1074 axes. 1075 1076 Examples: 1077 1. scalar value for whole tensor 1078 - in YAML 1079 ```yaml 1080 preprocessing: 1081 - id: fixed_zero_mean_unit_variance 1082 kwargs: 1083 mean: 103.5 1084 std: 13.7 1085 ``` 1086 - in Python 1087 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 1088 ... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7) 1089 ... )] 1090 1091 2. 
independently along an axis 1092 - in YAML 1093 ```yaml 1094 preprocessing: 1095 - id: fixed_zero_mean_unit_variance 1096 kwargs: 1097 axis: channel 1098 mean: [101.5, 102.5, 103.5] 1099 std: [11.7, 12.7, 13.7] 1100 ``` 1101 - in Python 1102 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 1103 ... kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 1104 ... axis=AxisId("channel"), 1105 ... mean=[101.5, 102.5, 103.5], 1106 ... std=[11.7, 12.7, 13.7], 1107 ... ) 1108 ... )] 1109 """ 1110 1111 id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance" 1112 kwargs: Union[ 1113 FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs 1114 ] 1115 1116 1117class ZeroMeanUnitVarianceKwargs(ProcessingKwargs): 1118 """key word arguments for `ZeroMeanUnitVarianceDescr`""" 1119 1120 axes: Annotated[ 1121 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1122 ] = None 1123 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 1124 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1125 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1126 To normalize each sample independently leave out the 'batch' axis. 1127 Default: Scale all axes jointly.""" 1128 1129 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1130 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`.""" 1131 1132 1133class ZeroMeanUnitVarianceDescr(ProcessingDescrBase): 1134 """Subtract mean and divide by variance. 1135 1136 Examples: 1137 Subtract tensor mean and variance 1138 - in YAML 1139 ```yaml 1140 preprocessing: 1141 - id: zero_mean_unit_variance 1142 ``` 1143 - in Python 1144 >>> preprocessing = [ZeroMeanUnitVarianceDescr()] 1145 """ 1146 1147 id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance" 1148 kwargs: ZeroMeanUnitVarianceKwargs = Field( 1149 default_factory=ZeroMeanUnitVarianceKwargs 1150 ) 1151 1152 1153class ScaleRangeKwargs(ProcessingKwargs): 1154 """key word arguments for `ScaleRangeDescr` 1155 1156 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default) 1157 this processing step normalizes data to the [0, 1] intervall. 1158 For other percentiles the normalized values will partially be outside the [0, 1] 1159 intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the 1160 normalized values to a range. 1161 """ 1162 1163 axes: Annotated[ 1164 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1165 ] = None 1166 """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value. 1167 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1168 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1169 To normalize samples independently, leave out the "batch" axis. 1170 Default: Scale all axes jointly.""" 1171 1172 min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0 1173 """The lower percentile used to determine the value to align with zero.""" 1174 1175 max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0 1176 """The upper percentile used to determine the value to align with one. 1177 Has to be bigger than `min_percentile`. 
1178 The range is 1 to 100 instead of 0 to 100 to avoid mistakenly 1179 accepting percentiles specified in the range 0.0 to 1.0.""" 1180 1181 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1182 """Epsilon for numeric stability. 1183 `out = (tensor - v_lower) / (v_upper - v_lower + eps)`; 1184 with `v_lower,v_upper` values at the respective percentiles.""" 1185 1186 reference_tensor: Optional[TensorId] = None 1187 """Tensor ID to compute the percentiles from. Default: The tensor itself. 1188 For any tensor in `inputs` only input tensor references are allowed.""" 1189 1190 @field_validator("max_percentile", mode="after") 1191 @classmethod 1192 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float: 1193 if (min_p := info.data["min_percentile"]) >= value: 1194 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}") 1195 1196 return value 1197 1198 1199class ScaleRangeDescr(ProcessingDescrBase): 1200 """Scale with percentiles. 1201 1202 Examples: 1203 1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0 1204 - in YAML 1205 ```yaml 1206 preprocessing: 1207 - id: scale_range 1208 kwargs: 1209 axes: ['y', 'x'] 1210 max_percentile: 99.8 1211 min_percentile: 5.0 1212 ``` 1213 - in Python 1214 >>> preprocessing = [ 1215 ... ScaleRangeDescr( 1216 ... kwargs=ScaleRangeKwargs( 1217 ... axes= (AxisId('y'), AxisId('x')), 1218 ... max_percentile= 99.8, 1219 ... min_percentile= 5.0, 1220 ... ) 1221 ... ), 1222 ... ClipDescr( 1223 ... kwargs=ClipKwargs( 1224 ... min=0.0, 1225 ... max=1.0, 1226 ... ) 1227 ... ), 1228 ... ] 1229 1230 2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles. 1231 - in YAML 1232 ```yaml 1233 preprocessing: 1234 - id: scale_range 1235 kwargs: 1236 axes: ['y', 'x'] 1237 max_percentile: 99.8 1238 min_percentile: 5.0 1239 - id: scale_range 1240 - id: clip 1241 kwargs: 1242 min: 0.0 1243 max: 1.0 1244 ``` 1245 - in Python 1246 >>> preprocessing = [ScaleRangeDescr( 1247 ... kwargs=ScaleRangeKwargs( 1248 ... axes= (AxisId('y'), AxisId('x')), 1249 ... max_percentile= 99.8, 1250 ... min_percentile= 5.0, 1251 ... ) 1252 ... )] 1253 1254 """ 1255 1256 id: Literal["scale_range"] = "scale_range" 1257 kwargs: ScaleRangeKwargs 1258 1259 1260class ScaleMeanVarianceKwargs(ProcessingKwargs): 1261 """key word arguments for `ScaleMeanVarianceKwargs`""" 1262 1263 reference_tensor: TensorId 1264 """Name of tensor to match.""" 1265 1266 axes: Annotated[ 1267 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1268 ] = None 1269 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 1270 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1271 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1272 To normalize samples independently, leave out the 'batch' axis. 1273 Default: Scale all axes jointly.""" 1274 1275 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1276 """Epsilon for numeric stability: 1277 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`""" 1278 1279 1280class ScaleMeanVarianceDescr(ProcessingDescrBase): 1281 """Scale a tensor's data distribution to match another tensor's mean/std. 
1282 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.` 1283 """ 1284 1285 id: Literal["scale_mean_variance"] = "scale_mean_variance" 1286 kwargs: ScaleMeanVarianceKwargs 1287 1288 1289PreprocessingDescr = Annotated[ 1290 Union[ 1291 BinarizeDescr, 1292 ClipDescr, 1293 EnsureDtypeDescr, 1294 ScaleLinearDescr, 1295 SigmoidDescr, 1296 FixedZeroMeanUnitVarianceDescr, 1297 ZeroMeanUnitVarianceDescr, 1298 ScaleRangeDescr, 1299 ], 1300 Discriminator("id"), 1301] 1302PostprocessingDescr = Annotated[ 1303 Union[ 1304 BinarizeDescr, 1305 ClipDescr, 1306 EnsureDtypeDescr, 1307 ScaleLinearDescr, 1308 SigmoidDescr, 1309 FixedZeroMeanUnitVarianceDescr, 1310 ZeroMeanUnitVarianceDescr, 1311 ScaleRangeDescr, 1312 ScaleMeanVarianceDescr, 1313 ], 1314 Discriminator("id"), 1315] 1316 1317IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis) 1318 1319 1320class TensorDescrBase(Node, Generic[IO_AxisT]): 1321 id: TensorId 1322 """Tensor id. No duplicates are allowed.""" 1323 1324 description: Annotated[str, MaxLen(128)] = "" 1325 """free text description""" 1326 1327 axes: NotEmpty[Sequence[IO_AxisT]] 1328 """tensor axes""" 1329 1330 @property 1331 def shape(self): 1332 return tuple(a.size for a in self.axes) 1333 1334 @field_validator("axes", mode="after", check_fields=False) 1335 @classmethod 1336 def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]: 1337 batch_axes = [a for a in axes if a.type == "batch"] 1338 if len(batch_axes) > 1: 1339 raise ValueError( 1340 f"Only one batch axis (per tensor) allowed, but got {batch_axes}" 1341 ) 1342 1343 seen_ids: Set[AxisId] = set() 1344 duplicate_axes_ids: Set[AxisId] = set() 1345 for a in axes: 1346 (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id) 1347 1348 if duplicate_axes_ids: 1349 raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}") 1350 1351 return axes 1352 1353 test_tensor: FileDescr 1354 """An example tensor to use for testing. 1355 Using the model with the test input tensors is expected to yield the test output tensors. 1356 Each test tensor has be a an ndarray in the 1357 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format). 1358 The file extension must be '.npy'.""" 1359 1360 sample_tensor: Optional[FileDescr] = None 1361 """A sample tensor to illustrate a possible input/output for the model, 1362 The sample image primarily serves to inform a human user about an example use case 1363 and is typically stored as .hdf5, .png or .tiff. 1364 It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats) 1365 (numpy's `.npy` format is not supported). 1366 The image dimensionality has to match the number of axes specified in this tensor description. 
1367 """ 1368 1369 @model_validator(mode="after") 1370 def _validate_sample_tensor(self) -> Self: 1371 if ( 1372 self.sample_tensor is None 1373 or not validation_context_var.get().perform_io_checks 1374 ): 1375 return self 1376 1377 local = download(self.sample_tensor.source, sha256=self.sample_tensor.sha256) 1378 tensor: NDArray[Any] = imread( 1379 local.path.read_bytes(), 1380 extension=PurePosixPath(local.original_file_name).suffix, 1381 ) 1382 n_dims = len(tensor.squeeze().shape) 1383 n_dims_min = n_dims_max = len(self.axes) 1384 1385 for a in self.axes: 1386 if isinstance(a, BatchAxis): 1387 n_dims_min -= 1 1388 elif isinstance(a.size, int): 1389 if a.size == 1: 1390 n_dims_min -= 1 1391 elif isinstance(a.size, (ParameterizedSize, DataDependentSize)): 1392 if a.size.min == 1: 1393 n_dims_min -= 1 1394 elif isinstance(a.size, SizeReference): 1395 if a.size.offset < 2: 1396 # size reference may result in singleton axis 1397 n_dims_min -= 1 1398 else: 1399 assert_never(a.size) 1400 1401 n_dims_min = max(0, n_dims_min) 1402 if n_dims < n_dims_min or n_dims > n_dims_max: 1403 raise ValueError( 1404 f"Expected sample tensor to have {n_dims_min} to" 1405 + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})." 1406 ) 1407 1408 return self 1409 1410 data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = ( 1411 IntervalOrRatioDataDescr() 1412 ) 1413 """Description of the tensor's data values, optionally per channel. 1414 If specified per channel, the data `type` needs to match across channels.""" 1415 1416 @property 1417 def dtype( 1418 self, 1419 ) -> Literal[ 1420 "float32", 1421 "float64", 1422 "uint8", 1423 "int8", 1424 "uint16", 1425 "int16", 1426 "uint32", 1427 "int32", 1428 "uint64", 1429 "int64", 1430 "bool", 1431 ]: 1432 """dtype as specified under `data.type` or `data[i].type`""" 1433 if isinstance(self.data, collections.abc.Sequence): 1434 return self.data[0].type 1435 else: 1436 return self.data.type 1437 1438 @field_validator("data", mode="after") 1439 @classmethod 1440 def _check_data_type_across_channels( 1441 cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] 1442 ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]: 1443 if not isinstance(value, list): 1444 return value 1445 1446 dtypes = {t.type for t in value} 1447 if len(dtypes) > 1: 1448 raise ValueError( 1449 "Tensor data descriptions per channel need to agree in their data" 1450 + f" `type`, but found {dtypes}." 1451 ) 1452 1453 return value 1454 1455 @model_validator(mode="after") 1456 def _check_data_matches_channelaxis(self) -> Self: 1457 if not isinstance(self.data, (list, tuple)): 1458 return self 1459 1460 for a in self.axes: 1461 if isinstance(a, ChannelAxis): 1462 size = a.size 1463 assert isinstance(size, int) 1464 break 1465 else: 1466 return self 1467 1468 if len(self.data) != size: 1469 raise ValueError( 1470 f"Got tensor data descriptions for {len(self.data)} channels, but" 1471 + f" '{a.id}' axis has size {size}." 1472 ) 1473 1474 return self 1475 1476 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]: 1477 if len(array.shape) != len(self.axes): 1478 raise ValueError( 1479 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})" 1480 + f" incompatible with {len(self.axes)} axes." 1481 ) 1482 return {a.id: array.shape[i] for i, a in enumerate(self.axes)} 1483 1484 1485class InputTensorDescr(TensorDescrBase[InputAxis]): 1486 id: TensorId = TensorId("input") 1487 """Input tensor id. 
1488 No duplicates are allowed across all inputs and outputs.""" 1489 1490 optional: bool = False 1491 """indicates that this tensor may be `None`""" 1492 1493 preprocessing: List[PreprocessingDescr] = Field(default_factory=list) 1494 """Description of how this input should be preprocessed. 1495 1496 notes: 1497 - If preprocessing does not start with an 'ensure_dtype' entry, it is added 1498 to ensure an input tensor's data type matches the input tensor's data description. 1499 - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 1500 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally 1501 changing the data type. 1502 """ 1503 1504 @model_validator(mode="after") 1505 def _validate_preprocessing_kwargs(self) -> Self: 1506 axes_ids = [a.id for a in self.axes] 1507 for p in self.preprocessing: 1508 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 1509 if kwargs_axes is None: 1510 continue 1511 1512 if not isinstance(kwargs_axes, collections.abc.Sequence): 1513 raise ValueError( 1514 f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}" 1515 ) 1516 1517 if any(a not in axes_ids for a in kwargs_axes): 1518 raise ValueError( 1519 "`preprocessing.i.kwargs.axes` needs to be subset of axes ids" 1520 ) 1521 1522 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 1523 dtype = self.data.type 1524 else: 1525 dtype = self.data[0].type 1526 1527 # ensure `preprocessing` begins with `EnsureDtypeDescr` 1528 if not self.preprocessing or not isinstance( 1529 self.preprocessing[0], EnsureDtypeDescr 1530 ): 1531 self.preprocessing.insert( 1532 0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1533 ) 1534 1535 # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 1536 if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)): 1537 self.preprocessing.append( 1538 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1539 ) 1540 1541 return self 1542 1543 1544def convert_axes( 1545 axes: str, 1546 *, 1547 shape: Union[ 1548 Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4 1549 ], 1550 tensor_type: Literal["input", "output"], 1551 halo: Optional[Sequence[int]], 1552 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 1553): 1554 ret: List[AnyAxis] = [] 1555 for i, a in enumerate(axes): 1556 axis_type = _AXIS_TYPE_MAP.get(a, a) 1557 if axis_type == "batch": 1558 ret.append(BatchAxis()) 1559 continue 1560 1561 scale = 1.0 1562 if isinstance(shape, _ParameterizedInputShape_v0_4): 1563 if shape.step[i] == 0: 1564 size = shape.min[i] 1565 else: 1566 size = ParameterizedSize(min=shape.min[i], step=shape.step[i]) 1567 elif isinstance(shape, _ImplicitOutputShape_v0_4): 1568 ref_t = str(shape.reference_tensor) 1569 if ref_t.count(".") == 1: 1570 t_id, orig_a_id = ref_t.split(".") 1571 else: 1572 t_id = ref_t 1573 orig_a_id = a 1574 1575 a_id = _AXIS_ID_MAP.get(orig_a_id, a) 1576 if not (orig_scale := shape.scale[i]): 1577 # old way to insert a new axis dimension 1578 size = int(2 * shape.offset[i]) 1579 else: 1580 scale = 1 / orig_scale 1581 if axis_type in ("channel", "index"): 1582 # these axes no longer have a scale 1583 offset_from_scale = orig_scale * size_refs.get( 1584 _TensorName_v0_4(t_id), {} 1585 ).get(orig_a_id, 0) 1586 else: 1587 offset_from_scale = 0 1588 size = SizeReference( 1589 tensor_id=TensorId(t_id), 1590 axis_id=AxisId(a_id), 1591 offset=int(offset_from_scale + 2 * shape.offset[i]), 1592 ) 1593 
else: 1594 size = shape[i] 1595 1596 if axis_type == "time": 1597 if tensor_type == "input": 1598 ret.append(TimeInputAxis(size=size, scale=scale)) 1599 else: 1600 assert not isinstance(size, ParameterizedSize) 1601 if halo is None: 1602 ret.append(TimeOutputAxis(size=size, scale=scale)) 1603 else: 1604 assert not isinstance(size, int) 1605 ret.append( 1606 TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i]) 1607 ) 1608 1609 elif axis_type == "index": 1610 if tensor_type == "input": 1611 ret.append(IndexInputAxis(size=size)) 1612 else: 1613 if isinstance(size, ParameterizedSize): 1614 size = DataDependentSize(min=size.min) 1615 1616 ret.append(IndexOutputAxis(size=size)) 1617 elif axis_type == "channel": 1618 assert not isinstance(size, ParameterizedSize) 1619 if isinstance(size, SizeReference): 1620 warnings.warn( 1621 "Conversion of channel size from an implicit output shape may be" 1622 + " wrong" 1623 ) 1624 ret.append( 1625 ChannelAxis( 1626 channel_names=[ 1627 Identifier(f"channel{i}") for i in range(size.offset) 1628 ] 1629 ) 1630 ) 1631 else: 1632 ret.append( 1633 ChannelAxis( 1634 channel_names=[Identifier(f"channel{i}") for i in range(size)] 1635 ) 1636 ) 1637 elif axis_type == "space": 1638 if tensor_type == "input": 1639 ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale)) 1640 else: 1641 assert not isinstance(size, ParameterizedSize) 1642 if halo is None or halo[i] == 0: 1643 ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale)) 1644 elif isinstance(size, int): 1645 raise NotImplementedError( 1646 f"output axis with halo and fixed size (here {size}) not allowed" 1647 ) 1648 else: 1649 ret.append( 1650 SpaceOutputAxisWithHalo( 1651 id=AxisId(a), size=size, scale=scale, halo=halo[i] 1652 ) 1653 ) 1654 1655 return ret 1656 1657 1658_AXIS_TYPE_MAP = { 1659 "b": "batch", 1660 "t": "time", 1661 "i": "index", 1662 "c": "channel", 1663 "x": "space", 1664 "y": "space", 1665 "z": "space", 1666} 1667 1668_AXIS_ID_MAP = { 1669 "b": "batch", 1670 "t": "time", 1671 "i": "index", 1672 "c": "channel", 1673} 1674 1675 1676def _axes_letters_to_ids( 1677 axes: Optional[str], 1678) -> Optional[List[AxisId]]: 1679 if axes is None: 1680 return None 1681 return [AxisId(_AXIS_ID_MAP.get(a, a)) for a in map(str, axes)] 1682 1683 1684def _get_complement_v04_axis( 1685 tensor_axes: Sequence[str], axes: Optional[Sequence[str]] 1686) -> Optional[AxisId]: 1687 if axes is None: 1688 return None 1689 1690 non_complement_axes = set(axes) | {"b"} 1691 complement_axes = [a for a in tensor_axes if a not in non_complement_axes] 1692 if len(complement_axes) > 1: 1693 raise ValueError( 1694 f"Expected none or a single complement axis, but axes '{axes}' " 1695 + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'." 
1696 ) 1697 1698 return None if not complement_axes else AxisId(complement_axes[0]) 1699 1700 1701def _convert_proc( 1702 p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4], 1703 tensor_axes: Sequence[str], 1704) -> Union[PreprocessingDescr, PostprocessingDescr]: 1705 if isinstance(p, _BinarizeDescr_v0_4): 1706 return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold)) 1707 elif isinstance(p, _ClipDescr_v0_4): 1708 return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max)) 1709 elif isinstance(p, _SigmoidDescr_v0_4): 1710 return SigmoidDescr() 1711 elif isinstance(p, _ScaleLinearDescr_v0_4): 1712 axes = _axes_letters_to_ids(p.kwargs.axes) 1713 if p.kwargs.axes is None: 1714 axis = None 1715 else: 1716 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes) 1717 1718 if axis is None: 1719 assert not isinstance(p.kwargs.gain, list) 1720 assert not isinstance(p.kwargs.offset, list) 1721 kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset) 1722 else: 1723 kwargs = ScaleLinearAlongAxisKwargs( 1724 axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset 1725 ) 1726 return ScaleLinearDescr(kwargs=kwargs) 1727 elif isinstance(p, _ScaleMeanVarianceDescr_v0_4): 1728 return ScaleMeanVarianceDescr( 1729 kwargs=ScaleMeanVarianceKwargs( 1730 axes=_axes_letters_to_ids(p.kwargs.axes), 1731 reference_tensor=TensorId(str(p.kwargs.reference_tensor)), 1732 eps=p.kwargs.eps, 1733 ) 1734 ) 1735 elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4): 1736 if p.kwargs.mode == "fixed": 1737 mean = p.kwargs.mean 1738 std = p.kwargs.std 1739 assert mean is not None 1740 assert std is not None 1741 1742 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes) 1743 1744 if axis is None: 1745 return FixedZeroMeanUnitVarianceDescr( 1746 kwargs=FixedZeroMeanUnitVarianceKwargs( 1747 mean=mean, std=std # pyright: ignore[reportArgumentType] 1748 ) 1749 ) 1750 else: 1751 if not isinstance(mean, list): 1752 mean = [float(mean)] 1753 if not isinstance(std, list): 1754 std = [float(std)] 1755 1756 return FixedZeroMeanUnitVarianceDescr( 1757 kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 1758 axis=axis, mean=mean, std=std 1759 ) 1760 ) 1761 1762 else: 1763 axes = _axes_letters_to_ids(p.kwargs.axes) or [] 1764 if p.kwargs.mode == "per_dataset": 1765 axes = [AxisId("batch")] + axes 1766 if not axes: 1767 axes = None 1768 return ZeroMeanUnitVarianceDescr( 1769 kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps) 1770 ) 1771 1772 elif isinstance(p, _ScaleRangeDescr_v0_4): 1773 return ScaleRangeDescr( 1774 kwargs=ScaleRangeKwargs( 1775 axes=_axes_letters_to_ids(p.kwargs.axes), 1776 min_percentile=p.kwargs.min_percentile, 1777 max_percentile=p.kwargs.max_percentile, 1778 eps=p.kwargs.eps, 1779 ) 1780 ) 1781 else: 1782 assert_never(p) 1783 1784 1785class _InputTensorConv( 1786 Converter[ 1787 _InputTensorDescr_v0_4, 1788 InputTensorDescr, 1789 ImportantFileSource, 1790 Optional[ImportantFileSource], 1791 Mapping[_TensorName_v0_4, Mapping[str, int]], 1792 ] 1793): 1794 def _convert( 1795 self, 1796 src: _InputTensorDescr_v0_4, 1797 tgt: "type[InputTensorDescr] | type[dict[str, Any]]", 1798 test_tensor: ImportantFileSource, 1799 sample_tensor: Optional[ImportantFileSource], 1800 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 1801 ) -> "InputTensorDescr | dict[str, Any]": 1802 axes: List[InputAxis] = convert_axes( # pyright: ignore[reportAssignmentType] 1803 src.axes, 1804 shape=src.shape, 1805 tensor_type="input", 1806 halo=None, 1807 size_refs=size_refs, 
1808 ) 1809 prep: List[PreprocessingDescr] = [] 1810 for p in src.preprocessing: 1811 cp = _convert_proc(p, src.axes) 1812 assert not isinstance(cp, ScaleMeanVarianceDescr) 1813 prep.append(cp) 1814 1815 return tgt( 1816 axes=axes, 1817 id=TensorId(str(src.name)), 1818 test_tensor=FileDescr(source=test_tensor), 1819 sample_tensor=( 1820 None if sample_tensor is None else FileDescr(source=sample_tensor) 1821 ), 1822 data=dict(type=src.data_type), # pyright: ignore[reportArgumentType] 1823 preprocessing=prep, 1824 ) 1825 1826 1827_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr) 1828 1829 1830class OutputTensorDescr(TensorDescrBase[OutputAxis]): 1831 id: TensorId = TensorId("output") 1832 """Output tensor id. 1833 No duplicates are allowed across all inputs and outputs.""" 1834 1835 postprocessing: List[PostprocessingDescr] = Field(default_factory=list) 1836 """Description of how this output should be postprocessed. 1837 1838 note: `postprocessing` always ends with an 'ensure_dtype' operation. 1839 If not given this is added to cast to this tensor's `data.type`. 1840 """ 1841 1842 @model_validator(mode="after") 1843 def _validate_postprocessing_kwargs(self) -> Self: 1844 axes_ids = [a.id for a in self.axes] 1845 for p in self.postprocessing: 1846 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 1847 if kwargs_axes is None: 1848 continue 1849 1850 if not isinstance(kwargs_axes, collections.abc.Sequence): 1851 raise ValueError( 1852 f"expected `axes` sequence, but got {type(kwargs_axes)}" 1853 ) 1854 1855 if any(a not in axes_ids for a in kwargs_axes): 1856 raise ValueError("`kwargs.axes` needs to be subset of axes ids") 1857 1858 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 1859 dtype = self.data.type 1860 else: 1861 dtype = self.data[0].type 1862 1863 # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 1864 if not self.postprocessing or not isinstance( 1865 self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr) 1866 ): 1867 self.postprocessing.append( 1868 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1869 ) 1870 return self 1871 1872 1873class _OutputTensorConv( 1874 Converter[ 1875 _OutputTensorDescr_v0_4, 1876 OutputTensorDescr, 1877 ImportantFileSource, 1878 Optional[ImportantFileSource], 1879 Mapping[_TensorName_v0_4, Mapping[str, int]], 1880 ] 1881): 1882 def _convert( 1883 self, 1884 src: _OutputTensorDescr_v0_4, 1885 tgt: "type[OutputTensorDescr] | type[dict[str, Any]]", 1886 test_tensor: ImportantFileSource, 1887 sample_tensor: Optional[ImportantFileSource], 1888 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 1889 ) -> "OutputTensorDescr | dict[str, Any]": 1890 # TODO: split convert_axes into convert_output_axes and convert_input_axes 1891 axes: List[OutputAxis] = convert_axes( # pyright: ignore[reportAssignmentType] 1892 src.axes, 1893 shape=src.shape, 1894 tensor_type="output", 1895 halo=src.halo, 1896 size_refs=size_refs, 1897 ) 1898 data_descr: Dict[str, Any] = dict(type=src.data_type) 1899 if data_descr["type"] == "bool": 1900 data_descr["values"] = [False, True] 1901 1902 return tgt( 1903 axes=axes, 1904 id=TensorId(str(src.name)), 1905 test_tensor=FileDescr(source=test_tensor), 1906 sample_tensor=( 1907 None if sample_tensor is None else FileDescr(source=sample_tensor) 1908 ), 1909 data=data_descr, # pyright: ignore[reportArgumentType] 1910 postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing], 1911 ) 1912 1913 1914_output_tensor_conv = 
_OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr) 1915 1916 1917TensorDescr = Union[InputTensorDescr, OutputTensorDescr] 1918 1919 1920def validate_tensors( 1921 tensors: Mapping[TensorId, Tuple[TensorDescr, NDArray[Any]]], 1922 tensor_origin: str, # for more precise error messages, e.g. 'test_tensor' 1923): 1924 all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, int]]] = {} 1925 1926 def e_msg(d: TensorDescr): 1927 return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]" 1928 1929 for descr, array in tensors.values(): 1930 try: 1931 axis_sizes = descr.get_axis_sizes_for_array(array) 1932 except ValueError as e: 1933 raise ValueError(f"{e_msg(descr)} {e}") 1934 else: 1935 all_tensor_axes[descr.id] = { 1936 a.id: (a, axis_sizes[a.id]) for a in descr.axes 1937 } 1938 1939 for descr, array in tensors.values(): 1940 if array.dtype.name != descr.dtype: 1941 raise ValueError( 1942 f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not" 1943 + f" match described dtype '{descr.dtype}'" 1944 ) 1945 1946 for a in descr.axes: 1947 actual_size = all_tensor_axes[descr.id][a.id][1] 1948 if a.size is None: 1949 continue 1950 1951 if isinstance(a.size, int): 1952 if actual_size != a.size: 1953 raise ValueError( 1954 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' " 1955 + f"has incompatible size {actual_size}, expected {a.size}" 1956 ) 1957 elif isinstance(a.size, ParameterizedSize): 1958 _ = a.size.validate_size(actual_size) 1959 elif isinstance(a.size, DataDependentSize): 1960 _ = a.size.validate_size(actual_size) 1961 elif isinstance(a.size, SizeReference): 1962 ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id) 1963 if ref_tensor_axes is None: 1964 raise ValueError( 1965 f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor" 1966 + f" reference '{a.size.tensor_id}'" 1967 ) 1968 1969 ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None)) 1970 if ref_axis is None or ref_size is None: 1971 raise ValueError( 1972 f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis" 1973 + f" reference '{a.size.tensor_id}.{a.size.axis_id}" 1974 ) 1975 1976 if a.unit != ref_axis.unit: 1977 raise ValueError( 1978 f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires" 1979 + " axis and reference axis to have the same `unit`, but" 1980 + f" {a.unit}!={ref_axis.unit}" 1981 ) 1982 1983 if actual_size != ( 1984 expected_size := ( 1985 ref_size * ref_axis.scale / a.scale + a.size.offset 1986 ) 1987 ): 1988 raise ValueError( 1989 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size" 1990 + f" {actual_size} invalid for referenced size {ref_size};" 1991 + f" expected {expected_size}" 1992 ) 1993 else: 1994 assert_never(a.size) 1995 1996 1997class EnvironmentFileDescr(FileDescr): 1998 source: Annotated[ 1999 ImportantFileSource, 2000 WithSuffix((".yaml", ".yml"), case_sensitive=True), 2001 Field( 2002 examples=["environment.yaml"], 2003 ), 2004 ] 2005 """∈📦 Conda environment file. 
2006 Allows to specify custom dependencies, see conda docs: 2007 - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms) 2008 - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually) 2009 """ 2010 2011 2012class _ArchitectureCallableDescr(Node): 2013 callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])] 2014 """Identifier of the callable that returns a torch.nn.Module instance.""" 2015 2016 kwargs: Dict[str, YamlValue] = Field(default_factory=dict) 2017 """key word arguments for the `callable`""" 2018 2019 2020class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr): 2021 pass 2022 2023 2024class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr): 2025 import_from: str 2026 """Where to import the callable from, i.e. `from <import_from> import <callable>`""" 2027 2028 2029ArchitectureDescr = Annotated[ 2030 Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr], 2031 Field(union_mode="left_to_right"), 2032] 2033 2034 2035class _ArchFileConv( 2036 Converter[ 2037 _CallableFromFile_v0_4, 2038 ArchitectureFromFileDescr, 2039 Optional[Sha256], 2040 Dict[str, Any], 2041 ] 2042): 2043 def _convert( 2044 self, 2045 src: _CallableFromFile_v0_4, 2046 tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]", 2047 sha256: Optional[Sha256], 2048 kwargs: Dict[str, Any], 2049 ) -> "ArchitectureFromFileDescr | dict[str, Any]": 2050 if src.startswith("http") and src.count(":") == 2: 2051 http, source, callable_ = src.split(":") 2052 source = ":".join((http, source)) 2053 elif not src.startswith("http") and src.count(":") == 1: 2054 source, callable_ = src.split(":") 2055 else: 2056 source = str(src) 2057 callable_ = str(src) 2058 return tgt( 2059 callable=Identifier(callable_), 2060 source=cast(ImportantFileSource, source), 2061 sha256=sha256, 2062 kwargs=kwargs, 2063 ) 2064 2065 2066_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr) 2067 2068 2069class _ArchLibConv( 2070 Converter[ 2071 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any] 2072 ] 2073): 2074 def _convert( 2075 self, 2076 src: _CallableFromDepencency_v0_4, 2077 tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]", 2078 kwargs: Dict[str, Any], 2079 ) -> "ArchitectureFromLibraryDescr | dict[str, Any]": 2080 *mods, callable_ = src.split(".") 2081 import_from = ".".join(mods) 2082 return tgt( 2083 import_from=import_from, callable=Identifier(callable_), kwargs=kwargs 2084 ) 2085 2086 2087_arch_lib_conv = _ArchLibConv( 2088 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr 2089) 2090 2091 2092class WeightsEntryDescrBase(FileDescr): 2093 type: ClassVar[WeightsFormat] 2094 weights_format_name: ClassVar[str] # human readable 2095 2096 source: ImportantFileSource 2097 """∈📦 The weights file.""" 2098 2099 authors: Optional[List[Author]] = None 2100 """Authors 2101 Either the person(s) that have trained this model resulting in the original weights file. 2102 (If this is the initial weights entry, i.e. it does not have a `parent`) 2103 Or the person(s) who have converted the weights to this weights format. 2104 (If this is a child weight, i.e. 
it has a `parent` field) 2105 """ 2106 2107 parent: Annotated[ 2108 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"]) 2109 ] = None 2110 """The source weights these weights were converted from. 2111 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`, 2112 The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights. 2113 All weight entries except one (the initial set of weights resulting from training the model), 2114 need to have this field.""" 2115 2116 @model_validator(mode="after") 2117 def check_parent_is_not_self(self) -> Self: 2118 if self.type == self.parent: 2119 raise ValueError("Weights entry can't be it's own parent.") 2120 2121 return self 2122 2123 2124class KerasHdf5WeightsDescr(WeightsEntryDescrBase): 2125 type = "keras_hdf5" 2126 weights_format_name: ClassVar[str] = "Keras HDF5" 2127 tensorflow_version: Version 2128 """TensorFlow version used to create these weights.""" 2129 2130 2131class OnnxWeightsDescr(WeightsEntryDescrBase): 2132 type = "onnx" 2133 weights_format_name: ClassVar[str] = "ONNX" 2134 opset_version: Annotated[int, Ge(7)] 2135 """ONNX opset version""" 2136 2137 2138class PytorchStateDictWeightsDescr(WeightsEntryDescrBase): 2139 type = "pytorch_state_dict" 2140 weights_format_name: ClassVar[str] = "Pytorch State Dict" 2141 architecture: ArchitectureDescr 2142 pytorch_version: Version 2143 """Version of the PyTorch library used. 2144 If `architecture.depencencies` is specified it has to include pytorch and any version pinning has to be compatible. 2145 """ 2146 dependencies: Optional[EnvironmentFileDescr] = None 2147 """Custom depencies beyond pytorch. 2148 The conda environment file should include pytorch and any version pinning has to be compatible with 2149 `pytorch_version`. 2150 """ 2151 2152 2153class TensorflowJsWeightsDescr(WeightsEntryDescrBase): 2154 type = "tensorflow_js" 2155 weights_format_name: ClassVar[str] = "Tensorflow.js" 2156 tensorflow_version: Version 2157 """Version of the TensorFlow library used.""" 2158 2159 source: ImportantFileSource 2160 """∈📦 The multi-file weights. 2161 All required files/folders should be a zip archive.""" 2162 2163 2164class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase): 2165 type = "tensorflow_saved_model_bundle" 2166 weights_format_name: ClassVar[str] = "Tensorflow Saved Model" 2167 tensorflow_version: Version 2168 """Version of the TensorFlow library used.""" 2169 2170 dependencies: Optional[EnvironmentFileDescr] = None 2171 """Custom dependencies beyond tensorflow. 2172 Should include tensorflow and any version pinning has to be compatible with `tensorflow_version`.""" 2173 2174 source: ImportantFileSource 2175 """∈📦 The multi-file weights. 
2176 All required files/folders should be a zip archive.""" 2177 2178 2179class TorchscriptWeightsDescr(WeightsEntryDescrBase): 2180 type = "torchscript" 2181 weights_format_name: ClassVar[str] = "TorchScript" 2182 pytorch_version: Version 2183 """Version of the PyTorch library used.""" 2184 2185 2186class WeightsDescr(Node): 2187 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None 2188 onnx: Optional[OnnxWeightsDescr] = None 2189 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None 2190 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None 2191 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = ( 2192 None 2193 ) 2194 torchscript: Optional[TorchscriptWeightsDescr] = None 2195 2196 @model_validator(mode="after") 2197 def check_entries(self) -> Self: 2198 entries = {wtype for wtype, entry in self if entry is not None} 2199 2200 if not entries: 2201 raise ValueError("Missing weights entry") 2202 2203 entries_wo_parent = { 2204 wtype 2205 for wtype, entry in self 2206 if entry is not None and hasattr(entry, "parent") and entry.parent is None 2207 } 2208 if len(entries_wo_parent) != 1: 2209 issue_warning( 2210 "Exactly one weights entry may not specify the `parent` field (got" 2211 + " {value}). That entry is considered the original set of model weights." 2212 + " Other weight formats are created through conversion of the orignal or" 2213 + " already converted weights. They have to reference the weights format" 2214 + " they were converted from as their `parent`.", 2215 value=len(entries_wo_parent), 2216 field="weights", 2217 ) 2218 2219 for wtype, entry in self: 2220 if entry is None: 2221 continue 2222 2223 assert hasattr(entry, "type") 2224 assert hasattr(entry, "parent") 2225 assert wtype == entry.type 2226 if ( 2227 entry.parent is not None and entry.parent not in entries 2228 ): # self reference checked for `parent` field 2229 raise ValueError( 2230 f"`weights.{wtype}.parent={entry.parent} not in specified weight" 2231 + f" formats: {entries}" 2232 ) 2233 2234 return self 2235 2236 def __getitem__( 2237 self, 2238 key: Literal[ 2239 "keras_hdf5", 2240 "onnx", 2241 "pytorch_state_dict", 2242 "tensorflow_js", 2243 "tensorflow_saved_model_bundle", 2244 "torchscript", 2245 ], 2246 ): 2247 if key == "keras_hdf5": 2248 ret = self.keras_hdf5 2249 elif key == "onnx": 2250 ret = self.onnx 2251 elif key == "pytorch_state_dict": 2252 ret = self.pytorch_state_dict 2253 elif key == "tensorflow_js": 2254 ret = self.tensorflow_js 2255 elif key == "tensorflow_saved_model_bundle": 2256 ret = self.tensorflow_saved_model_bundle 2257 elif key == "torchscript": 2258 ret = self.torchscript 2259 else: 2260 raise KeyError(key) 2261 2262 if ret is None: 2263 raise KeyError(key) 2264 2265 return ret 2266 2267 @property 2268 def available_formats(self): 2269 return { 2270 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 2271 **({} if self.onnx is None else {"onnx": self.onnx}), 2272 **( 2273 {} 2274 if self.pytorch_state_dict is None 2275 else {"pytorch_state_dict": self.pytorch_state_dict} 2276 ), 2277 **( 2278 {} 2279 if self.tensorflow_js is None 2280 else {"tensorflow_js": self.tensorflow_js} 2281 ), 2282 **( 2283 {} 2284 if self.tensorflow_saved_model_bundle is None 2285 else { 2286 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 2287 } 2288 ), 2289 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 2290 } 2291 2292 @property 2293 def missing_formats(self): 2294 return { 2295 wf for wf in 
get_args(WeightsFormat) if wf not in self.available_formats 2296 } 2297 2298 2299class ModelId(ResourceId): 2300 pass 2301 2302 2303class LinkedModel(LinkedResourceBase): 2304 """Reference to a bioimage.io model.""" 2305 2306 id: ModelId 2307 """A valid model `id` from the bioimage.io collection.""" 2308 2309 2310class _DataDepSize(NamedTuple): 2311 min: int 2312 max: Optional[int] 2313 2314 2315class _AxisSizes(NamedTuple): 2316 """the lenghts of all axes of model inputs and outputs""" 2317 2318 inputs: Dict[Tuple[TensorId, AxisId], int] 2319 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] 2320 2321 2322class _TensorSizes(NamedTuple): 2323 """_AxisSizes as nested dicts""" 2324 2325 inputs: Dict[TensorId, Dict[AxisId, int]] 2326 outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]] 2327 2328 2329class ModelDescr(GenericModelDescrBase): 2330 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. 2331 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 2332 """ 2333 2334 format_version: Literal["0.5.3"] = "0.5.3" 2335 """Version of the bioimage.io model description specification used. 2336 When creating a new model always use the latest micro/patch version described here. 2337 The `format_version` is important for any consumer software to understand how to parse the fields. 2338 """ 2339 2340 type: Literal["model"] = "model" 2341 """Specialized resource type 'model'""" 2342 2343 id: Optional[ModelId] = None 2344 """bioimage.io-wide unique resource identifier 2345 assigned by bioimage.io; version **un**specific.""" 2346 2347 authors: NotEmpty[List[Author]] 2348 """The authors are the creators of the model RDF and the primary points of contact.""" 2349 2350 documentation: Annotated[ 2351 DocumentationSource, 2352 Field( 2353 examples=[ 2354 "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md", 2355 "README.md", 2356 ], 2357 ), 2358 ] 2359 """∈📦 URL or relative path to a markdown file with additional documentation. 2360 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. 
2361 The documentation should include a '#[#] Validation' (sub)section 2362 with details on how to quantitatively validate the model on unseen data.""" 2363 2364 @field_validator("documentation", mode="after") 2365 @classmethod 2366 def _validate_documentation(cls, value: DocumentationSource) -> DocumentationSource: 2367 if not validation_context_var.get().perform_io_checks: 2368 return value 2369 2370 doc_path = download(value).path 2371 doc_content = doc_path.read_text(encoding="utf-8") 2372 assert isinstance(doc_content, str) 2373 if not re.match("#.*[vV]alidation", doc_content): 2374 issue_warning( 2375 "No '# Validation' (sub)section found in {value}.", 2376 value=value, 2377 field="documentation", 2378 ) 2379 2380 return value 2381 2382 inputs: NotEmpty[Sequence[InputTensorDescr]] 2383 """Describes the input tensors expected by this model.""" 2384 2385 @field_validator("inputs", mode="after") 2386 @classmethod 2387 def _validate_input_axes( 2388 cls, inputs: Sequence[InputTensorDescr] 2389 ) -> Sequence[InputTensorDescr]: 2390 input_size_refs = cls._get_axes_with_independent_size(inputs) 2391 2392 for i, ipt in enumerate(inputs): 2393 valid_independent_refs: Dict[ 2394 Tuple[TensorId, AxisId], 2395 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2396 ] = { 2397 **{ 2398 (ipt.id, a.id): (ipt, a, a.size) 2399 for a in ipt.axes 2400 if not isinstance(a, BatchAxis) 2401 and isinstance(a.size, (int, ParameterizedSize)) 2402 }, 2403 **input_size_refs, 2404 } 2405 for a, ax in enumerate(ipt.axes): 2406 cls._validate_axis( 2407 "inputs", 2408 i=i, 2409 tensor_id=ipt.id, 2410 a=a, 2411 axis=ax, 2412 valid_independent_refs=valid_independent_refs, 2413 ) 2414 return inputs 2415 2416 @staticmethod 2417 def _validate_axis( 2418 field_name: str, 2419 i: int, 2420 tensor_id: TensorId, 2421 a: int, 2422 axis: AnyAxis, 2423 valid_independent_refs: Dict[ 2424 Tuple[TensorId, AxisId], 2425 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2426 ], 2427 ): 2428 if isinstance(axis, BatchAxis) or isinstance( 2429 axis.size, (int, ParameterizedSize, DataDependentSize) 2430 ): 2431 return 2432 elif not isinstance(axis.size, SizeReference): 2433 assert_never(axis.size) 2434 2435 # validate axis.size SizeReference 2436 ref = (axis.size.tensor_id, axis.size.axis_id) 2437 if ref not in valid_independent_refs: 2438 raise ValueError( 2439 "Invalid tensor axis reference at" 2440 + f" {field_name}[{i}].axes[{a}].size: {axis.size}." 2441 ) 2442 if ref == (tensor_id, axis.id): 2443 raise ValueError( 2444 "Self-referencing not allowed for" 2445 + f" {field_name}[{i}].axes[{a}].size: {axis.size}" 2446 ) 2447 if axis.type == "channel": 2448 if valid_independent_refs[ref][1].type != "channel": 2449 raise ValueError( 2450 "A channel axis' size may only reference another fixed size" 2451 + " channel axis." 2452 ) 2453 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names: 2454 ref_size = valid_independent_refs[ref][2] 2455 assert isinstance(ref_size, int), ( 2456 "channel axis ref (another channel axis) has to specify fixed" 2457 + " size" 2458 ) 2459 generated_channel_names = [ 2460 Identifier(axis.channel_names.format(i=i)) 2461 for i in range(1, ref_size + 1) 2462 ] 2463 axis.channel_names = generated_channel_names 2464 2465 if (ax_unit := getattr(axis, "unit", None)) != ( 2466 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None) 2467 ): 2468 raise ValueError( 2469 "The units of an axis and its reference axis need to match, but" 2470 + f" '{ax_unit}' != '{ref_unit}'." 
2471 ) 2472 ref_axis = valid_independent_refs[ref][1] 2473 if isinstance(ref_axis, BatchAxis): 2474 raise ValueError( 2475 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}" 2476 + " (a batch axis is not allowed as reference)." 2477 ) 2478 2479 if isinstance(axis, WithHalo): 2480 min_size = axis.size.get_size(axis, ref_axis, n=0) 2481 if (min_size - 2 * axis.halo) < 1: 2482 raise ValueError( 2483 f"axis {axis.id} with minimum size {min_size} is too small for halo" 2484 + f" {axis.halo}." 2485 ) 2486 2487 input_halo = axis.halo * axis.scale / ref_axis.scale 2488 if input_halo != int(input_halo) or input_halo % 2 == 1: 2489 raise ValueError( 2490 f"input_halo {input_halo} (output_halo {axis.halo} *" 2491 + f" output_scale {axis.scale} / input_scale {ref_axis.scale})" 2492 + f" is not an even integer for {tensor_id}.{axis.id}." 2493 ) 2494 2495 @model_validator(mode="after") 2496 def _validate_test_tensors(self) -> Self: 2497 if not validation_context_var.get().perform_io_checks: 2498 return self 2499 2500 test_arrays = [ 2501 load_array(descr.test_tensor.download().path) 2502 for descr in chain(self.inputs, self.outputs) 2503 ] 2504 tensors = { 2505 descr.id: (descr, array) 2506 for descr, array in zip(chain(self.inputs, self.outputs), test_arrays) 2507 } 2508 validate_tensors(tensors, tensor_origin="test_tensor") 2509 return self 2510 2511 @model_validator(mode="after") 2512 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self: 2513 ipt_refs = {t.id for t in self.inputs} 2514 out_refs = {t.id for t in self.outputs} 2515 for ipt in self.inputs: 2516 for p in ipt.preprocessing: 2517 ref = p.kwargs.get("reference_tensor") 2518 if ref is None: 2519 continue 2520 if ref not in ipt_refs: 2521 raise ValueError( 2522 f"`reference_tensor` '{ref}' not found. Valid input tensor" 2523 + f" references are: {ipt_refs}." 2524 ) 2525 2526 for out in self.outputs: 2527 for p in out.postprocessing: 2528 ref = p.kwargs.get("reference_tensor") 2529 if ref is None: 2530 continue 2531 2532 if ref not in ipt_refs and ref not in out_refs: 2533 raise ValueError( 2534 f"`reference_tensor` '{ref}' not found. Valid tensor references" 2535 + f" are: {ipt_refs | out_refs}." 2536 ) 2537 2538 return self 2539 2540 # TODO: use validate funcs in validate_test_tensors 2541 # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]: 2542 2543 name: Annotated[ 2544 Annotated[ 2545 str, RestrictCharacters(string.ascii_letters + string.digits + "_- ()") 2546 ], 2547 MinLen(5), 2548 MaxLen(128), 2549 warn(MaxLen(64), "Name longer than 64 characters.", INFO), 2550 ] 2551 """A human-readable name of this model. 2552 It should be no longer than 64 characters 2553 and may only contain letter, number, underscore, minus, parentheses and spaces. 2554 We recommend to chose a name that refers to the model's task and image modality. 
2555 """ 2556 2557 outputs: NotEmpty[Sequence[OutputTensorDescr]] 2558 """Describes the output tensors.""" 2559 2560 @field_validator("outputs", mode="after") 2561 @classmethod 2562 def _validate_tensor_ids( 2563 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo 2564 ) -> Sequence[OutputTensorDescr]: 2565 tensor_ids = [ 2566 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", []) 2567 ] 2568 duplicate_tensor_ids: List[str] = [] 2569 seen: Set[str] = set() 2570 for t in tensor_ids: 2571 if t in seen: 2572 duplicate_tensor_ids.append(t) 2573 2574 seen.add(t) 2575 2576 if duplicate_tensor_ids: 2577 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}") 2578 2579 return outputs 2580 2581 @staticmethod 2582 def _get_axes_with_parameterized_size( 2583 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2584 ): 2585 return { 2586 f"{t.id}.{a.id}": (t, a, a.size) 2587 for t in io 2588 for a in t.axes 2589 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize) 2590 } 2591 2592 @staticmethod 2593 def _get_axes_with_independent_size( 2594 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2595 ): 2596 return { 2597 (t.id, a.id): (t, a, a.size) 2598 for t in io 2599 for a in t.axes 2600 if not isinstance(a, BatchAxis) 2601 and isinstance(a.size, (int, ParameterizedSize)) 2602 } 2603 2604 @field_validator("outputs", mode="after") 2605 @classmethod 2606 def _validate_output_axes( 2607 cls, outputs: List[OutputTensorDescr], info: ValidationInfo 2608 ) -> List[OutputTensorDescr]: 2609 input_size_refs = cls._get_axes_with_independent_size( 2610 info.data.get("inputs", []) 2611 ) 2612 output_size_refs = cls._get_axes_with_independent_size(outputs) 2613 2614 for i, out in enumerate(outputs): 2615 valid_independent_refs: Dict[ 2616 Tuple[TensorId, AxisId], 2617 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2618 ] = { 2619 **{ 2620 (out.id, a.id): (out, a, a.size) 2621 for a in out.axes 2622 if not isinstance(a, BatchAxis) 2623 and isinstance(a.size, (int, ParameterizedSize)) 2624 }, 2625 **input_size_refs, 2626 **output_size_refs, 2627 } 2628 for a, ax in enumerate(out.axes): 2629 cls._validate_axis( 2630 "outputs", 2631 i, 2632 out.id, 2633 a, 2634 ax, 2635 valid_independent_refs=valid_independent_refs, 2636 ) 2637 2638 return outputs 2639 2640 packaged_by: List[Author] = Field(default_factory=list) 2641 """The persons that have packaged and uploaded this model. 2642 Only required if those persons differ from the `authors`.""" 2643 2644 parent: Optional[LinkedModel] = None 2645 """The model from which this model is derived, e.g. by fine-tuning the weights.""" 2646 2647 # todo: add parent self check once we have `id` 2648 # @model_validator(mode="after") 2649 # def validate_parent_is_not_self(self) -> Self: 2650 # if self.parent is not None and self.parent == self.id: 2651 # raise ValueError("The model may not reference itself as parent model") 2652 2653 # return self 2654 2655 run_mode: Annotated[ 2656 Optional[RunMode], 2657 warn(None, "Run mode '{value}' has limited support across consumer softwares."), 2658 ] = None 2659 """Custom run mode for this model: for more complex prediction procedures like test time 2660 data augmentation that currently cannot be expressed in the specification. 
2661 No standard run modes are defined yet.""" 2662 2663 timestamp: Datetime = Datetime(datetime.now()) 2664 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format 2665 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat). 2666 (In Python a datetime object is valid, too).""" 2667 2668 training_data: Annotated[ 2669 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02], 2670 Field(union_mode="left_to_right"), 2671 ] = None 2672 """The dataset used to train this model""" 2673 2674 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)] 2675 """The weights for this model. 2676 Weights can be given for different formats, but should otherwise be equivalent. 2677 The available weight formats determine which consumers can use this model.""" 2678 2679 @model_validator(mode="after") 2680 def _add_default_cover(self) -> Self: 2681 if not validation_context_var.get().perform_io_checks or self.covers: 2682 return self 2683 2684 try: 2685 generated_covers = generate_covers( 2686 [(t, load_array(t.test_tensor.download().path)) for t in self.inputs], 2687 [(t, load_array(t.test_tensor.download().path)) for t in self.outputs], 2688 ) 2689 except Exception as e: 2690 issue_warning( 2691 "Failed to generate cover image(s): {e}", 2692 value=self.covers, 2693 msg_context=dict(e=e), 2694 field="covers", 2695 ) 2696 else: 2697 self.covers.extend(generated_covers) 2698 2699 return self 2700 2701 def get_input_test_arrays(self) -> List[NDArray[Any]]: 2702 data = [load_array(ipt.test_tensor.download().path) for ipt in self.inputs] 2703 assert all(isinstance(d, np.ndarray) for d in data) 2704 return data 2705 2706 def get_output_test_arrays(self) -> List[NDArray[Any]]: 2707 data = [load_array(out.test_tensor.download().path) for out in self.outputs] 2708 assert all(isinstance(d, np.ndarray) for d in data) 2709 return data 2710 2711 @staticmethod 2712 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 2713 batch_size = 1 2714 tensor_with_batchsize: Optional[TensorId] = None 2715 for tid in tensor_sizes: 2716 for aid, s in tensor_sizes[tid].items(): 2717 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 2718 continue 2719 2720 if batch_size != 1: 2721 assert tensor_with_batchsize is not None 2722 raise ValueError( 2723 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 2724 ) 2725 2726 batch_size = s 2727 tensor_with_batchsize = tid 2728 2729 return batch_size 2730 2731 def get_output_tensor_sizes( 2732 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 2733 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: 2734 """Returns the tensor output sizes for given **input_sizes**. 2735 Only if **input_sizes** has a valid input shape, the tensor output size is exact. 
2736 Otherwise it might be larger than the actual (valid) output""" 2737 batch_size = self.get_batch_size(input_sizes) 2738 ns = self.get_ns(input_sizes) 2739 2740 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 2741 return tensor_sizes.outputs 2742 2743 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 2744 """get parameter `n` for each parameterized axis 2745 such that the valid input size is >= the given input size""" 2746 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 2747 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 2748 for tid in input_sizes: 2749 for aid, s in input_sizes[tid].items(): 2750 size_descr = axes[tid][aid].size 2751 if isinstance(size_descr, ParameterizedSize): 2752 ret[(tid, aid)] = size_descr.get_n(s) 2753 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 2754 pass 2755 else: 2756 assert_never(size_descr) 2757 2758 return ret 2759 2760 def get_tensor_sizes( 2761 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 2762 ) -> _TensorSizes: 2763 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 2764 return _TensorSizes( 2765 { 2766 t: { 2767 aa: axis_sizes.inputs[(tt, aa)] 2768 for tt, aa in axis_sizes.inputs 2769 if tt == t 2770 } 2771 for t in {tt for tt, _ in axis_sizes.inputs} 2772 }, 2773 { 2774 t: { 2775 aa: axis_sizes.outputs[(tt, aa)] 2776 for tt, aa in axis_sizes.outputs 2777 if tt == t 2778 } 2779 for t in {tt for tt, _ in axis_sizes.outputs} 2780 }, 2781 ) 2782 2783 def get_axis_sizes( 2784 self, 2785 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], 2786 batch_size: Optional[int] = None, 2787 *, 2788 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None, 2789 ) -> _AxisSizes: 2790 """Determine input and output block shape for scale factors **ns** 2791 of parameterized input sizes. 2792 2793 Args: 2794 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) 2795 that is parameterized as `size = min + n * step`. 2796 batch_size: The desired size of the batch dimension. 2797 If given **batch_size** overwrites any batch size present in 2798 **max_input_shape**. Default 1. 2799 max_input_shape: Limits the derived block shapes. 2800 Each axis for which the input size, parameterized by `n`, is larger 2801 than **max_input_shape** is set to the minimal value `n_min` for which 2802 this is still true. 2803 Use this for small input samples or large values of **ns**. 2804 Or simply whenever you know the full input shape. 2805 2806 Returns: 2807 Resolved axis sizes for model inputs and outputs. 
2808 """ 2809 max_input_shape = max_input_shape or {} 2810 if batch_size is None: 2811 for (_t_id, a_id), s in max_input_shape.items(): 2812 if a_id == BATCH_AXIS_ID: 2813 batch_size = s 2814 break 2815 else: 2816 batch_size = 1 2817 2818 all_axes = { 2819 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs) 2820 } 2821 2822 inputs: Dict[Tuple[TensorId, AxisId], int] = {} 2823 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {} 2824 2825 def get_axis_size(a: Union[InputAxis, OutputAxis]): 2826 if isinstance(a, BatchAxis): 2827 if (t_descr.id, a.id) in ns: 2828 logger.warning( 2829 "Ignoring unexpected size increment factor (n) for batch axis" 2830 + " of tensor '{}'.", 2831 t_descr.id, 2832 ) 2833 return batch_size 2834 elif isinstance(a.size, int): 2835 if (t_descr.id, a.id) in ns: 2836 logger.warning( 2837 "Ignoring unexpected size increment factor (n) for fixed size" 2838 + " axis '{}' of tensor '{}'.", 2839 a.id, 2840 t_descr.id, 2841 ) 2842 return a.size 2843 elif isinstance(a.size, ParameterizedSize): 2844 if (t_descr.id, a.id) not in ns: 2845 raise ValueError( 2846 "Size increment factor (n) missing for parametrized axis" 2847 + f" '{a.id}' of tensor '{t_descr.id}'." 2848 ) 2849 n = ns[(t_descr.id, a.id)] 2850 s_max = max_input_shape.get((t_descr.id, a.id)) 2851 if s_max is not None: 2852 n = min(n, a.size.get_n(s_max)) 2853 2854 return a.size.get_size(n) 2855 2856 elif isinstance(a.size, SizeReference): 2857 if (t_descr.id, a.id) in ns: 2858 logger.warning( 2859 "Ignoring unexpected size increment factor (n) for axis '{}'" 2860 + " of tensor '{}' with size reference.", 2861 a.id, 2862 t_descr.id, 2863 ) 2864 assert not isinstance(a, BatchAxis) 2865 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id] 2866 assert not isinstance(ref_axis, BatchAxis) 2867 ref_key = (a.size.tensor_id, a.size.axis_id) 2868 ref_size = inputs.get(ref_key, outputs.get(ref_key)) 2869 assert ref_size is not None, ref_key 2870 assert not isinstance(ref_size, _DataDepSize), ref_key 2871 return a.size.get_size( 2872 axis=a, 2873 ref_axis=ref_axis, 2874 ref_size=ref_size, 2875 ) 2876 elif isinstance(a.size, DataDependentSize): 2877 if (t_descr.id, a.id) in ns: 2878 logger.warning( 2879 "Ignoring unexpected increment factor (n) for data dependent" 2880 + " size axis '{}' of tensor '{}'.", 2881 a.id, 2882 t_descr.id, 2883 ) 2884 return _DataDepSize(a.size.min, a.size.max) 2885 else: 2886 assert_never(a.size) 2887 2888 # first resolve all , but the `SizeReference` input sizes 2889 for t_descr in self.inputs: 2890 for a in t_descr.axes: 2891 if not isinstance(a.size, SizeReference): 2892 s = get_axis_size(a) 2893 assert not isinstance(s, _DataDepSize) 2894 inputs[t_descr.id, a.id] = s 2895 2896 # resolve all other input axis sizes 2897 for t_descr in self.inputs: 2898 for a in t_descr.axes: 2899 if isinstance(a.size, SizeReference): 2900 s = get_axis_size(a) 2901 assert not isinstance(s, _DataDepSize) 2902 inputs[t_descr.id, a.id] = s 2903 2904 # resolve all output axis sizes 2905 for t_descr in self.outputs: 2906 for a in t_descr.axes: 2907 assert not isinstance(a.size, ParameterizedSize) 2908 s = get_axis_size(a) 2909 outputs[t_descr.id, a.id] = s 2910 2911 return _AxisSizes(inputs=inputs, outputs=outputs) 2912 2913 @model_validator(mode="before") 2914 @classmethod 2915 def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]: 2916 if ( 2917 data.get("type") == "model" 2918 and isinstance(fv := data.get("format_version"), str) 2919 and fv.count(".") == 2 2920 ): 2921 
fv_parts = fv.split(".") 2922 if any(not p.isdigit() for p in fv_parts): 2923 return data 2924 2925 fv_tuple = tuple(map(int, fv_parts)) 2926 2927 assert cls.implemented_format_version_tuple[0:2] == (0, 5) 2928 if fv_tuple[:2] in ((0, 3), (0, 4)): 2929 m04 = _ModelDescr_v0_4.load(data) 2930 if not isinstance(m04, InvalidDescr): 2931 return _model_conv.convert_as_dict(m04) 2932 elif fv_tuple[:2] == (0, 5): 2933 # bump patch version 2934 data["format_version"] = cls.implemented_format_version 2935 2936 return data 2937 2938 2939class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]): 2940 def _convert( 2941 self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]" 2942 ) -> "ModelDescr | dict[str, Any]": 2943 name = "".join( 2944 c if c in string.ascii_letters + string.digits + "_- ()" else " " 2945 for c in src.name 2946 ) 2947 2948 def conv_authors(auths: Optional[Sequence[_Author_v0_4]]): 2949 conv = ( 2950 _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict 2951 ) 2952 return None if auths is None else [conv(a) for a in auths] 2953 2954 if TYPE_CHECKING: 2955 arch_file_conv = _arch_file_conv.convert 2956 arch_lib_conv = _arch_lib_conv.convert 2957 else: 2958 arch_file_conv = _arch_file_conv.convert_as_dict 2959 arch_lib_conv = _arch_lib_conv.convert_as_dict 2960 2961 input_size_refs = { 2962 ipt.name: { 2963 a: s 2964 for a, s in zip( 2965 ipt.axes, 2966 ( 2967 ipt.shape.min 2968 if isinstance(ipt.shape, _ParameterizedInputShape_v0_4) 2969 else ipt.shape 2970 ), 2971 ) 2972 } 2973 for ipt in src.inputs 2974 if ipt.shape 2975 } 2976 output_size_refs = { 2977 **{ 2978 out.name: {a: s for a, s in zip(out.axes, out.shape)} 2979 for out in src.outputs 2980 if not isinstance(out.shape, _ImplicitOutputShape_v0_4) 2981 }, 2982 **input_size_refs, 2983 } 2984 2985 return tgt( 2986 attachments=( 2987 [] 2988 if src.attachments is None 2989 else [FileDescr(source=f) for f in src.attachments.files] 2990 ), 2991 authors=[ 2992 _author_conv.convert_as_dict(a) for a in src.authors 2993 ], # pyright: ignore[reportArgumentType] 2994 cite=[ 2995 {"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite 2996 ], # pyright: ignore[reportArgumentType] 2997 config=src.config, 2998 covers=src.covers, 2999 description=src.description, 3000 documentation=src.documentation, 3001 format_version="0.5.3", 3002 git_repo=src.git_repo, # pyright: ignore[reportArgumentType] 3003 icon=src.icon, 3004 id=None if src.id is None else ModelId(src.id), 3005 id_emoji=src.id_emoji, 3006 license=src.license, # type: ignore 3007 links=src.links, 3008 maintainers=[ 3009 _maintainer_conv.convert_as_dict(m) for m in src.maintainers 3010 ], # pyright: ignore[reportArgumentType] 3011 name=name, 3012 tags=src.tags, 3013 type=src.type, 3014 uploader=src.uploader, 3015 version=src.version, 3016 inputs=[ # pyright: ignore[reportArgumentType] 3017 _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs) 3018 for ipt, tt, st, in zip( 3019 src.inputs, 3020 src.test_inputs, 3021 src.sample_inputs or [None] * len(src.test_inputs), 3022 ) 3023 ], 3024 outputs=[ # pyright: ignore[reportArgumentType] 3025 _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs) 3026 for out, tt, st, in zip( 3027 src.outputs, 3028 src.test_outputs, 3029 src.sample_outputs or [None] * len(src.test_outputs), 3030 ) 3031 ], 3032 parent=( 3033 None 3034 if src.parent is None 3035 else LinkedModel( 3036 id=ModelId( 3037 str(src.parent.id) 3038 + ( 3039 "" 3040 if src.parent.version_number is None 3041 else 
f"/{src.parent.version_number}" 3042 ) 3043 ) 3044 ) 3045 ), 3046 training_data=( 3047 None 3048 if src.training_data is None 3049 else ( 3050 LinkedDataset( 3051 id=DatasetId( 3052 str(src.training_data.id) 3053 + ( 3054 "" 3055 if src.training_data.version_number is None 3056 else f"/{src.training_data.version_number}" 3057 ) 3058 ) 3059 ) 3060 if isinstance(src.training_data, LinkedDataset02) 3061 else src.training_data 3062 ) 3063 ), 3064 packaged_by=[ 3065 _author_conv.convert_as_dict(a) for a in src.packaged_by 3066 ], # pyright: ignore[reportArgumentType] 3067 run_mode=src.run_mode, 3068 timestamp=src.timestamp, 3069 weights=(WeightsDescr if TYPE_CHECKING else dict)( 3070 keras_hdf5=(w := src.weights.keras_hdf5) 3071 and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)( 3072 authors=conv_authors(w.authors), 3073 source=w.source, 3074 tensorflow_version=w.tensorflow_version or Version("1.15"), 3075 parent=w.parent, 3076 ), 3077 onnx=(w := src.weights.onnx) 3078 and (OnnxWeightsDescr if TYPE_CHECKING else dict)( 3079 source=w.source, 3080 authors=conv_authors(w.authors), 3081 parent=w.parent, 3082 opset_version=w.opset_version or 15, 3083 ), 3084 pytorch_state_dict=(w := src.weights.pytorch_state_dict) 3085 and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)( 3086 source=w.source, 3087 authors=conv_authors(w.authors), 3088 parent=w.parent, 3089 architecture=( 3090 arch_file_conv( 3091 w.architecture, 3092 w.architecture_sha256, 3093 w.kwargs, 3094 ) 3095 if isinstance(w.architecture, _CallableFromFile_v0_4) 3096 else arch_lib_conv(w.architecture, w.kwargs) 3097 ), 3098 pytorch_version=w.pytorch_version or Version("1.10"), 3099 dependencies=( 3100 None 3101 if w.dependencies is None 3102 else (EnvironmentFileDescr if TYPE_CHECKING else dict)( 3103 source=cast( 3104 ImportantFileSource, 3105 str(deps := w.dependencies)[ 3106 ( 3107 len("conda:") 3108 if str(deps).startswith("conda:") 3109 else 0 3110 ) : 3111 ], 3112 ) 3113 ) 3114 ), 3115 ), 3116 tensorflow_js=(w := src.weights.tensorflow_js) 3117 and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)( 3118 source=w.source, 3119 authors=conv_authors(w.authors), 3120 parent=w.parent, 3121 tensorflow_version=w.tensorflow_version or Version("1.15"), 3122 ), 3123 tensorflow_saved_model_bundle=( 3124 w := src.weights.tensorflow_saved_model_bundle 3125 ) 3126 and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)( 3127 authors=conv_authors(w.authors), 3128 parent=w.parent, 3129 source=w.source, 3130 tensorflow_version=w.tensorflow_version or Version("1.15"), 3131 dependencies=( 3132 None 3133 if w.dependencies is None 3134 else (EnvironmentFileDescr if TYPE_CHECKING else dict)( 3135 source=cast( 3136 ImportantFileSource, 3137 ( 3138 str(w.dependencies)[len("conda:") :] 3139 if str(w.dependencies).startswith("conda:") 3140 else str(w.dependencies) 3141 ), 3142 ) 3143 ) 3144 ), 3145 ), 3146 torchscript=(w := src.weights.torchscript) 3147 and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)( 3148 source=w.source, 3149 authors=conv_authors(w.authors), 3150 parent=w.parent, 3151 pytorch_version=w.pytorch_version or Version("1.10"), 3152 ), 3153 ), 3154 ) 3155 3156 3157_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr) 3158 3159 3160# create better cover images for 3d data and non-image outputs 3161def generate_covers( 3162 inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]], 3163 outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]], 3164) -> List[Path]: 3165 def squeeze( 3166 data: NDArray[Any], 
axes: Sequence[AnyAxis] 3167 ) -> Tuple[NDArray[Any], List[AnyAxis]]: 3168 """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining""" 3169 if data.ndim != len(axes): 3170 raise ValueError( 3171 f"tensor shape {data.shape} does not match described axes" 3172 + f" {[a.id for a in axes]}" 3173 ) 3174 3175 axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1] 3176 return data.squeeze(), axes 3177 3178 def normalize( 3179 data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7 3180 ) -> NDArray[np.float32]: 3181 data = data.astype("float32") 3182 data -= data.min(axis=axis, keepdims=True) 3183 data /= data.max(axis=axis, keepdims=True) + eps 3184 return data 3185 3186 def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]): 3187 original_shape = data.shape 3188 data, axes = squeeze(data, axes) 3189 3190 # take slice fom any batch or index axis if needed 3191 # and convert the first channel axis and take a slice from any additional channel axes 3192 slices: Tuple[slice, ...] = () 3193 ndim = data.ndim 3194 ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2 3195 has_c_axis = False 3196 for i, a in enumerate(axes): 3197 s = data.shape[i] 3198 assert s > 1 3199 if ( 3200 isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis)) 3201 and ndim > ndim_need 3202 ): 3203 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3204 ndim -= 1 3205 elif isinstance(a, ChannelAxis): 3206 if has_c_axis: 3207 # second channel axis 3208 data = data[slices + (slice(0, 1),)] 3209 ndim -= 1 3210 else: 3211 has_c_axis = True 3212 if s == 2: 3213 # visualize two channels with cyan and magenta 3214 data = np.concatenate( 3215 [ 3216 data[slices + (slice(1, 2),)], 3217 data[slices + (slice(0, 1),)], 3218 ( 3219 data[slices + (slice(0, 1),)] 3220 + data[slices + (slice(1, 2),)] 3221 ) 3222 / 2, # TODO: take maximum instead? 3223 ], 3224 axis=i, 3225 ) 3226 elif data.shape[i] == 3: 3227 pass # visualize 3 channels as RGB 3228 else: 3229 # visualize first 3 channels as RGB 3230 data = data[slices + (slice(3),)] 3231 3232 assert data.shape[i] == 3 3233 3234 slices += (slice(None),) 3235 3236 data, axes = squeeze(data, axes) 3237 assert len(axes) == ndim 3238 # take slice from z axis if needed 3239 slices = () 3240 if ndim > ndim_need: 3241 for i, a in enumerate(axes): 3242 s = data.shape[i] 3243 if a.id == AxisId("z"): 3244 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3245 data, axes = squeeze(data, axes) 3246 ndim -= 1 3247 break 3248 3249 slices += (slice(None),) 3250 3251 # take slice from any space or time axis 3252 slices = () 3253 3254 for i, a in enumerate(axes): 3255 if ndim <= ndim_need: 3256 break 3257 3258 s = data.shape[i] 3259 assert s > 1 3260 if isinstance( 3261 a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis) 3262 ): 3263 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3264 ndim -= 1 3265 3266 slices += (slice(None),) 3267 3268 del slices 3269 data, axes = squeeze(data, axes) 3270 assert len(axes) == ndim 3271 3272 if (has_c_axis and ndim != 3) or ndim != 2: 3273 raise ValueError( 3274 f"Failed to construct cover image from shape {original_shape}" 3275 ) 3276 3277 if not has_c_axis: 3278 assert ndim == 2 3279 data = np.repeat(data[:, :, None], 3, axis=2) 3280 axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB")))) 3281 ndim += 1 3282 3283 assert ndim == 3 3284 3285 # transpose axis order such that longest axis comes first... 
3286 axis_order = list(np.argsort(list(data.shape))) 3287 axis_order.reverse() 3288 # ... and channel axis is last 3289 c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0] 3290 axis_order.append(axis_order.pop(c)) 3291 axes = [axes[ao] for ao in axis_order] 3292 data = data.transpose(axis_order) 3293 3294 # h, w = data.shape[:2] 3295 # if h / w in (1.0 or 2.0): 3296 # pass 3297 # elif h / w < 2: 3298 # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images 3299 3300 norm_along = ( 3301 tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None 3302 ) 3303 # normalize the data and map to 8 bit 3304 data = normalize(data, norm_along) 3305 data = (data * 255).astype("uint8") 3306 3307 return data 3308 3309 def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]): 3310 assert im0.dtype == im1.dtype == np.uint8 3311 assert im0.shape == im1.shape 3312 assert im0.ndim == 3 3313 N, M, C = im0.shape 3314 assert C == 3 3315 out = np.ones((N, M, C), dtype="uint8") 3316 for c in range(C): 3317 outc = np.tril(im0[..., c]) 3318 mask = outc == 0 3319 outc[mask] = np.triu(im1[..., c])[mask] 3320 out[..., c] = outc 3321 3322 return out 3323 3324 ipt_descr, ipt = inputs[0] 3325 out_descr, out = outputs[0] 3326 3327 ipt_img = to_2d_image(ipt, ipt_descr.axes) 3328 out_img = to_2d_image(out, out_descr.axes) 3329 3330 cover_folder = Path(mkdtemp()) 3331 if ipt_img.shape == out_img.shape: 3332 covers = [cover_folder / "cover.png"] 3333 imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img)) 3334 else: 3335 covers = [cover_folder / "input.png", cover_folder / "output.png"] 3336 imwrite(covers[0], ipt_img) 3337 imwrite(covers[1], out_img) 3338 3339 return covers
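A minimal sketch of the pure size helpers defined above (no file IO involved): `ModelDescr.get_batch_size` only inspects a mapping of axis sizes. The tensor and axis ids below are made up, and the sketch assumes the module's batch axis id is `batch`.

>>> from bioimageio.spec.model.v0_5 import AxisId, ModelDescr, TensorId
>>> ModelDescr.get_batch_size(
...     {
...         TensorId("raw"): {AxisId("batch"): 2, AxisId("x"): 512},
...         TensorId("mask"): {AxisId("batch"): 2, AxisId("x"): 512},
...     }
... )
2

If two tensors specified different batch sizes (other than 1), a `ValueError` would be raised instead.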
Space unit compatible with the OME-Zarr axes specification 0.5
Time unit compatible with the OME-Zarr axes specification 0.5
194class TensorId(LowerCaseIdentifier): 195 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 196 Annotated[LowerCaseIdentifierAnno, MaxLen(32)] 197 ]
200class AxisId(LowerCaseIdentifier): 201 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 202 Annotated[LowerCaseIdentifierAnno, MaxLen(16)] 203 ]
244class ParameterizedSize(Node): 245 """Describes a range of valid tensor axis sizes as `size = min + n*step`.""" 246 247 N: ClassVar[Type[int]] = ParameterizedSize_N 248 """integer to parameterize this axis""" 249 250 min: Annotated[int, Gt(0)] 251 step: Annotated[int, Gt(0)] 252 253 def validate_size(self, size: int) -> int: 254 if size < self.min: 255 raise ValueError(f"size {size} < {self.min}") 256 if (size - self.min) % self.step != 0: 257 raise ValueError( 258 f"axis of size {size} is not parameterized by `min + n*step` =" 259 + f" `{self.min} + n*{self.step}`" 260 ) 261 262 return size 263 264 def get_size(self, n: ParameterizedSize_N) -> int: 265 return self.min + self.step * n 266 267 def get_n(self, s: int) -> ParameterizedSize_N: 268 """return smallest n parameterizing a size greater or equal than `s`""" 269 return ceil((s - self.min) / self.step)
Describes a range of valid tensor axis sizes as `size = min + n*step`.
253 def validate_size(self, size: int) -> int: 254 if size < self.min: 255 raise ValueError(f"size {size} < {self.min}") 256 if (size - self.min) % self.step != 0: 257 raise ValueError( 258 f"axis of size {size} is not parameterized by `min + n*step` =" 259 + f" `{self.min} + n*{self.step}`" 260 ) 261 262 return size
267 def get_n(self, s: int) -> ParameterizedSize_N: 268 """return smallest n parameterizing a size greater or equal than `s`""" 269 return ceil((s - self.min) / self.step)
return the smallest `n` parameterizing a size greater than or equal to `s`
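For concreteness, a small sketch with arbitrary values showing how `get_size`, `get_n` and `validate_size` relate:

>>> size = ParameterizedSize(min=32, step=16)
>>> size.get_size(2)
64
>>> size.get_n(40)
1
>>> size.validate_size(48)
48

Valid sizes are exactly 32, 48, 64, ...; `validate_size(40)` would raise a `ValueError`.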
272class DataDependentSize(Node): 273 min: Annotated[int, Gt(0)] = 1 274 max: Annotated[Optional[int], Gt(1)] = None 275 276 @model_validator(mode="after") 277 def _validate_max_gt_min(self): 278 if self.max is not None and self.min >= self.max: 279 raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}") 280 281 return self 282 283 def validate_size(self, size: int) -> int: 284 if size < self.min: 285 raise ValueError(f"size {size} < {self.min}") 286 287 if self.max is not None and size > self.max: 288 raise ValueError(f"size {size} > {self.max}") 289 290 return size
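A brief sketch of the bounds check (the numbers are arbitrary):

>>> DataDependentSize(min=1, max=512).validate_size(256)
256

A size below `min` or above `max` raises a `ValueError`; the actual size is only known after inference.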
293class SizeReference(Node): 294 """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis. 295 296 `axis.size = reference.size * reference.scale / axis.scale + offset` 297 298 Note: 299 1. The axis and the referenced axis need to have the same unit (or no unit). 300 2. Batch axes may not be referenced. 301 3. Fractions are rounded down. 302 4. If the reference axis is `concatenable` the referencing axis is assumed to be 303 `concatenable` as well with the same block order. 304 305 Example: 306 An unisotropic input image of w*h=100*49 pixels depicts a phsical space of 200*196mm². 307 Let's assume that we want to express the image height h in relation to its width w 308 instead of only accepting input images of exactly 100*49 pixels 309 (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`). 310 311 >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2) 312 >>> h = SpaceInputAxis( 313 ... id=AxisId("h"), 314 ... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1), 315 ... unit="millimeter", 316 ... scale=4, 317 ... ) 318 >>> print(h.size.get_size(h, w)) 319 49 320 321 ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49 322 """ 323 324 tensor_id: TensorId 325 """tensor id of the reference axis""" 326 327 axis_id: AxisId 328 """axis id of the reference axis""" 329 330 offset: int = 0 331 332 def get_size( 333 self, 334 axis: Union[ 335 ChannelAxis, 336 IndexInputAxis, 337 IndexOutputAxis, 338 TimeInputAxis, 339 SpaceInputAxis, 340 TimeOutputAxis, 341 TimeOutputAxisWithHalo, 342 SpaceOutputAxis, 343 SpaceOutputAxisWithHalo, 344 ], 345 ref_axis: Union[ 346 ChannelAxis, 347 IndexInputAxis, 348 IndexOutputAxis, 349 TimeInputAxis, 350 SpaceInputAxis, 351 TimeOutputAxis, 352 TimeOutputAxisWithHalo, 353 SpaceOutputAxis, 354 SpaceOutputAxisWithHalo, 355 ], 356 n: ParameterizedSize_N = 0, 357 ref_size: Optional[int] = None, 358 ): 359 """Compute the concrete size for a given axis and its reference axis. 360 361 Args: 362 axis: The axis this `SizeReference` is the size of. 363 ref_axis: The reference axis to compute the size from. 364 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`) 365 and no fixed **ref_size** is given, 366 **n** is used to compute the size of the parameterized **ref_axis**. 367 ref_size: Overwrite the reference size instead of deriving it from 368 **ref_axis** 369 (**ref_axis.scale** is still used; any given **n** is ignored). 370 """ 371 assert ( 372 axis.size == self 373 ), "Given `axis.size` is not defined by this `SizeReference`" 374 375 assert ( 376 ref_axis.id == self.axis_id 377 ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}." 378 379 assert axis.unit == ref_axis.unit, ( 380 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`," 381 f" but {axis.unit}!={ref_axis.unit}" 382 ) 383 if ref_size is None: 384 if isinstance(ref_axis.size, (int, float)): 385 ref_size = ref_axis.size 386 elif isinstance(ref_axis.size, ParameterizedSize): 387 ref_size = ref_axis.size.get_size(n) 388 elif isinstance(ref_axis.size, DataDependentSize): 389 raise ValueError( 390 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`." 391 ) 392 elif isinstance(ref_axis.size, SizeReference): 393 raise ValueError( 394 "Reference axis referenced in `SizeReference` may not be sized by a" 395 + " `SizeReference` itself." 
396 ) 397 else: 398 assert_never(ref_axis.size) 399 400 return int(ref_size * ref_axis.scale / axis.scale + self.offset) 401 402 @staticmethod 403 def _get_unit( 404 axis: Union[ 405 ChannelAxis, 406 IndexInputAxis, 407 IndexOutputAxis, 408 TimeInputAxis, 409 SpaceInputAxis, 410 TimeOutputAxis, 411 TimeOutputAxisWithHalo, 412 SpaceOutputAxis, 413 SpaceOutputAxisWithHalo, 414 ], 415 ): 416 return axis.unit
A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.
axis.size = reference.size * reference.scale / axis.scale + offset
Note:
- The axis and the referenced axis need to have the same unit (or no unit).
- Batch axes may not be referenced.
- Fractions are rounded down.
- If the reference axis is `concatenable` the referencing axis is assumed to be `concatenable` as well with the same block order.
Example:
An anisotropic input image of w*h = 100*49 pixels depicts a physical space of 200*196 mm².
Let's assume that we want to express the image height h in relation to its width w
instead of only accepting input images of exactly 100*49 pixels
(for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).
>>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
>>> h = SpaceInputAxis(
... id=AxisId("h"),
... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
... unit="millimeter",
... scale=4,
... )
>>> print(h.size.get_size(h, w))
49
⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
332 def get_size( 333 self, 334 axis: Union[ 335 ChannelAxis, 336 IndexInputAxis, 337 IndexOutputAxis, 338 TimeInputAxis, 339 SpaceInputAxis, 340 TimeOutputAxis, 341 TimeOutputAxisWithHalo, 342 SpaceOutputAxis, 343 SpaceOutputAxisWithHalo, 344 ], 345 ref_axis: Union[ 346 ChannelAxis, 347 IndexInputAxis, 348 IndexOutputAxis, 349 TimeInputAxis, 350 SpaceInputAxis, 351 TimeOutputAxis, 352 TimeOutputAxisWithHalo, 353 SpaceOutputAxis, 354 SpaceOutputAxisWithHalo, 355 ], 356 n: ParameterizedSize_N = 0, 357 ref_size: Optional[int] = None, 358 ): 359 """Compute the concrete size for a given axis and its reference axis. 360 361 Args: 362 axis: The axis this `SizeReference` is the size of. 363 ref_axis: The reference axis to compute the size from. 364 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`) 365 and no fixed **ref_size** is given, 366 **n** is used to compute the size of the parameterized **ref_axis**. 367 ref_size: Overwrite the reference size instead of deriving it from 368 **ref_axis** 369 (**ref_axis.scale** is still used; any given **n** is ignored). 370 """ 371 assert ( 372 axis.size == self 373 ), "Given `axis.size` is not defined by this `SizeReference`" 374 375 assert ( 376 ref_axis.id == self.axis_id 377 ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}." 378 379 assert axis.unit == ref_axis.unit, ( 380 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`," 381 f" but {axis.unit}!={ref_axis.unit}" 382 ) 383 if ref_size is None: 384 if isinstance(ref_axis.size, (int, float)): 385 ref_size = ref_axis.size 386 elif isinstance(ref_axis.size, ParameterizedSize): 387 ref_size = ref_axis.size.get_size(n) 388 elif isinstance(ref_axis.size, DataDependentSize): 389 raise ValueError( 390 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`." 391 ) 392 elif isinstance(ref_axis.size, SizeReference): 393 raise ValueError( 394 "Reference axis referenced in `SizeReference` may not be sized by a" 395 + " `SizeReference` itself." 396 ) 397 else: 398 assert_never(ref_axis.size) 399 400 return int(ref_size * ref_axis.scale / axis.scale + self.offset)
Compute the concrete size for a given axis and its reference axis.
Arguments:
- axis: The axis this `SizeReference` is the size of.
- ref_axis: The reference axis to compute the size from.
- n: If the ref_axis is parameterized (of type `ParameterizedSize`) and no fixed ref_size is given, n is used to compute the size of the parameterized ref_axis.
- ref_size: Overwrite the reference size instead of deriving it from ref_axis (ref_axis.scale is still used; any given n is ignored).
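For illustration, here is a minimal, hypothetical sketch of resolving a `SizeReference` against a parameterized reference axis via n (all ids and numbers are made up; the names are assumed to be importable from `bioimageio.spec.model.v0_5`):

from bioimageio.spec.model.v0_5 import (
    AxisId,
    ParameterizedSize,
    SizeReference,
    SpaceInputAxis,
    SpaceOutputAxis,
    TensorId,
)

# reference axis with a parameterized size: valid extents are 64 + n * 16
w = SpaceInputAxis(id=AxisId("w"), size=ParameterizedSize(min=64, step=16))

# output axis defined relative to the reference axis (both axes have no unit here)
h = SpaceOutputAxis(
    id=AxisId("h"),
    size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=0),
)

# n=2 resolves the parameterized reference to 64 + 2 * 16 = 96; both scales default to 1.0
print(h.size.get_size(h, w, n=2))  # 96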
419class AxisBase(NodeWithExplicitlySetFields): 420 fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({"type"}) 421 422 id: AxisId 423 """An axis id unique across all axes of one tensor.""" 424 425 description: Annotated[str, MaxLen(128)] = ""
428class WithHalo(Node): 429 halo: Annotated[int, Ge(1)] 430 """The halo should be cropped from the output tensor to avoid boundary effects. 431 It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`. 432 To document a halo that is already cropped by the model use `size.offset` instead.""" 433 434 size: Annotated[ 435 SizeReference, 436 Field( 437 examples=[ 438 10, 439 SizeReference( 440 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 441 ).model_dump(mode="json"), 442 ] 443 ), 444 ] 445 """reference to another axis with an optional offset (see `SizeReference`)"""
The halo should be cropped from the output tensor to avoid boundary effects.
It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
To document a halo that is already cropped by the model use `size.offset` instead.
reference to another axis with an optional offset (see `SizeReference`)
451class BatchAxis(AxisBase): 452 type: Literal["batch"] = "batch" 453 id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID 454 size: Optional[Literal[1]] = None 455 """The batch size may be fixed to 1, 456 otherwise (the default) it may be chosen arbitrarily depending on available memory""" 457 458 @property 459 def scale(self): 460 return 1.0 461 462 @property 463 def concatenable(self): 464 return True 465 466 @property 467 def unit(self): 468 return None
471class ChannelAxis(AxisBase): 472 type: Literal["channel"] = "channel" 473 id: NonBatchAxisId = AxisId("channel") 474 channel_names: NotEmpty[List[Identifier]] 475 476 @property 477 def size(self) -> int: 478 return len(self.channel_names) 479 480 @property 481 def concatenable(self): 482 return False 483 484 @property 485 def scale(self) -> float: 486 return 1.0 487 488 @property 489 def unit(self): 490 return None
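As a quick illustration (a minimal sketch; the channel names are made up), the size of a `ChannelAxis` is derived from its `channel_names`:

from bioimageio.spec.model.v0_5 import ChannelAxis, Identifier

rgb = ChannelAxis(channel_names=[Identifier("red"), Identifier("green"), Identifier("blue")])
print(rgb.size)          # 3, the number of channel names
print(rgb.concatenable)  # False, channel axes cannot be split into blocks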
526class IndexInputAxis(IndexAxisBase, _WithInputAxisSize): 527 concatenable: bool = False 528 """If a model has a `concatenable` input axis, it can be processed blockwise, 529 splitting a longer sample axis into blocks matching its input tensor description. 530 Output axes are concatenable if they have a `SizeReference` to a concatenable 531 input axis. 532 """
If a model has a `concatenable` input axis, it can be processed blockwise,
splitting a longer sample axis into blocks matching its input tensor description.
Output axes are concatenable if they have a `SizeReference` to a concatenable input axis.
535class IndexOutputAxis(IndexAxisBase): 536 size: Annotated[ 537 Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize], 538 Field( 539 examples=[ 540 10, 541 SizeReference( 542 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 543 ).model_dump(mode="json"), 544 ] 545 ), 546 ] 547 """The size/length of this axis can be specified as 548 - fixed integer 549 - reference to another axis with an optional offset (`SizeReference`) 550 - data dependent size using `DataDependentSize` (size is only known after model inference) 551 """
The size/length of this axis can be specified as
- fixed integer
- reference to another axis with an optional offset (`SizeReference`)
- data dependent size using `DataDependentSize` (size is only known after model inference)
A minimal construction sketch of these options follows below.
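A minimal sketch of the three options (tensor/axis ids and numbers are made up):

from bioimageio.spec.model.v0_5 import (
    AxisId,
    DataDependentSize,
    IndexOutputAxis,
    SizeReference,
    TensorId,
)

fixed = IndexOutputAxis(size=10)  # fixed integer
referenced = IndexOutputAxis(     # relative to another axis, with an offset
    size=SizeReference(tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5)
)
data_dependent = IndexOutputAxis(size=DataDependentSize(min=1))  # only known after inference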
554class TimeAxisBase(AxisBase): 555 type: Literal["time"] = "time" 556 id: NonBatchAxisId = AxisId("time") 557 unit: Optional[TimeUnit] = None 558 scale: Annotated[float, Gt(0)] = 1.0
561class TimeInputAxis(TimeAxisBase, _WithInputAxisSize): 562 concatenable: bool = False 563 """If a model has a `concatenable` input axis, it can be processed blockwise, 564 splitting a longer sample axis into blocks matching its input tensor description. 565 Output axes are concatenable if they have a `SizeReference` to a concatenable 566 input axis. 567 """
If a model has a `concatenable` input axis, it can be processed blockwise,
splitting a longer sample axis into blocks matching its input tensor description.
Output axes are concatenable if they have a `SizeReference` to a concatenable input axis.
570class SpaceAxisBase(AxisBase): 571 type: Literal["space"] = "space" 572 id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x") 573 unit: Optional[SpaceUnit] = None 574 scale: Annotated[float, Gt(0)] = 1.0
An axis id unique across all axes of one tensor.
577class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize): 578 concatenable: bool = False 579 """If a model has a `concatenable` input axis, it can be processed blockwise, 580 splitting a longer sample axis into blocks matching its input tensor description. 581 Output axes are concatenable if they have a `SizeReference` to a concatenable 582 input axis. 583 """
If a model has a `concatenable` input axis, it can be processed blockwise,
splitting a longer sample axis into blocks matching its input tensor description.
Output axes are concatenable if they have a `SizeReference` to a concatenable input axis.
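For illustration, a concatenable spatial input axis with a parameterized size (a minimal sketch; the numbers and unit are made up):

from bioimageio.spec.model.v0_5 import AxisId, ParameterizedSize, SpaceInputAxis

# valid extents along 'z' are 8 + n * 4; blocks along this axis may be concatenated
z = SpaceInputAxis(
    id=AxisId("z"),
    size=ParameterizedSize(min=8, step=4),
    scale=2.0,
    unit="micrometer",
    concatenable=True,
)
print(z.size.get_size(n=3))  # 20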
681class NominalOrOrdinalDataDescr(Node): 682 values: TVs 683 """A fixed set of nominal or an ascending sequence of ordinal values. 684 In this case `data_type` is required to be an unsigend integer type, e.g. 'uint8'. 685 String `values` are interpreted as labels for tensor values 0, ..., N. 686 Note: as YAML 1.2 does not natively support a "set" datatype, 687 nominal values should be given as a sequence (aka list/array) as well. 688 """ 689 690 type: Annotated[ 691 NominalOrOrdinalDType, 692 Field( 693 examples=[ 694 "float32", 695 "uint8", 696 "uint16", 697 "int64", 698 "bool", 699 ], 700 ), 701 ] = "uint8" 702 703 @model_validator(mode="after") 704 def _validate_values_match_type( 705 self, 706 ) -> Self: 707 incompatible: List[Any] = [] 708 for v in self.values: 709 if self.type == "bool": 710 if not isinstance(v, bool): 711 incompatible.append(v) 712 elif self.type in DTYPE_LIMITS: 713 if ( 714 isinstance(v, (int, float)) 715 and ( 716 v < DTYPE_LIMITS[self.type].min 717 or v > DTYPE_LIMITS[self.type].max 718 ) 719 or (isinstance(v, str) and "uint" not in self.type) 720 or (isinstance(v, float) and "int" in self.type) 721 ): 722 incompatible.append(v) 723 else: 724 incompatible.append(v) 725 726 if len(incompatible) == 5: 727 incompatible.append("...") 728 break 729 730 if incompatible: 731 raise ValueError( 732 f"data type '{self.type}' incompatible with values {incompatible}" 733 ) 734 735 return self 736 737 unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None 738 739 @property 740 def range(self): 741 if isinstance(self.values[0], str): 742 return 0, len(self.values) - 1 743 else: 744 return min(self.values), max(self.values)
A fixed set of nominal or an ascending sequence of ordinal values.
In this case `data_type` is required to be an unsigned integer type, e.g. 'uint8'.
String `values` are interpreted as labels for tensor values 0, ..., N.
Note: as YAML 1.2 does not natively support a "set" datatype,
nominal values should be given as a sequence (aka list/array) as well.
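A minimal sketch describing a label tensor with three nominal values (the labels are made up):

from bioimageio.spec.model.v0_5 import NominalOrOrdinalDataDescr

labels = NominalOrOrdinalDataDescr(
    values=["background", "cell", "membrane"],  # interpreted as labels for tensor values 0, 1, 2
    type="uint8",
)
print(labels.range)  # (0, 2)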
761class IntervalOrRatioDataDescr(Node): 762 type: Annotated[ # todo: rename to dtype 763 IntervalOrRatioDType, 764 Field( 765 examples=["float32", "float64", "uint8", "uint16"], 766 ), 767 ] = "float32" 768 range: Tuple[Optional[float], Optional[float]] = ( 769 None, 770 None, 771 ) 772 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor. 773 `None` corresponds to min/max of what can be expressed by `data_type`.""" 774 unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit" 775 scale: float = 1.0 776 """Scale for data on an interval (or ratio) scale.""" 777 offset: Optional[float] = None 778 """Offset for data on a ratio scale."""
Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
`None` corresponds to min/max of what can be expressed by `data_type`.
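For illustration, a float32 intensity description restricted to the [0, 1] range (a minimal sketch):

from bioimageio.spec.model.v0_5 import IntervalOrRatioDataDescr

intensity = IntervalOrRatioDataDescr(
    type="float32",
    range=(0.0, 1.0),  # `None` on either side would mean the dtype's min/max
    unit="arbitrary unit",
)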
784class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC): 785 """processing base class""" 786 787 # id: Literal[PreprocessingId, PostprocessingId] # make abstract field 788 fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({"id"})
processing base class
791class BinarizeKwargs(ProcessingKwargs): 792 """key word arguments for `BinarizeDescr`""" 793 794 threshold: float 795 """The fixed threshold"""
keyword arguments for `BinarizeDescr`
798class BinarizeAlongAxisKwargs(ProcessingKwargs): 799 """key word arguments for `BinarizeDescr`""" 800 801 threshold: NotEmpty[List[float]] 802 """The fixed threshold values along `axis`""" 803 804 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 805 """The `threshold` axis"""
keyword arguments for `BinarizeDescr`
808class BinarizeDescr(ProcessingDescrBase): 809 """Binarize the tensor with a fixed threshold. 810 811 Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold` 812 will be set to one, values below the threshold to zero. 813 814 Examples: 815 - in YAML 816 ```yaml 817 postprocessing: 818 - id: binarize 819 kwargs: 820 axis: 'channel' 821 threshold: [0.25, 0.5, 0.75] 822 ``` 823 - in Python: 824 >>> postprocessing = [BinarizeDescr( 825 ... kwargs=BinarizeAlongAxisKwargs( 826 ... axis=AxisId('channel'), 827 ... threshold=[0.25, 0.5, 0.75], 828 ... ) 829 ... )] 830 """ 831 832 id: Literal["binarize"] = "binarize" 833 kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]
Binarize the tensor with a fixed threshold.
Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold`
will be set to one, values below the threshold to zero.
Examples:
- in YAML
postprocessing:
- id: binarize
kwargs:
axis: 'channel'
threshold: [0.25, 0.5, 0.75]
- in Python:
>>> postprocessing = [BinarizeDescr(
...     kwargs=BinarizeAlongAxisKwargs(
...         axis=AxisId('channel'),
...         threshold=[0.25, 0.5, 0.75],
...     )
... )]
836class ClipDescr(ProcessingDescrBase): 837 """Set tensor values below min to min and above max to max. 838 839 See `ScaleRangeDescr` for examples. 840 """ 841 842 id: Literal["clip"] = "clip" 843 kwargs: ClipKwargs
Set tensor values below min to min and above max to max.
See `ScaleRangeDescr` for examples.
846class EnsureDtypeKwargs(ProcessingKwargs): 847 """key word arguments for `EnsureDtypeDescr`""" 848 849 dtype: Literal[ 850 "float32", 851 "float64", 852 "uint8", 853 "int8", 854 "uint16", 855 "int16", 856 "uint32", 857 "int32", 858 "uint64", 859 "int64", 860 "bool", 861 ]
keyword arguments for `EnsureDtypeDescr`
864class EnsureDtypeDescr(ProcessingDescrBase): 865 """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching). 866 867 This can for example be used to ensure the inner neural network model gets a 868 different input tensor data type than the fully described bioimage.io model does. 869 870 Examples: 871 The described bioimage.io model (incl. preprocessing) accepts any 872 float32-compatible tensor, normalizes it with percentiles and clipping and then 873 casts it to uint8, which is what the neural network in this example expects. 874 - in YAML 875 ```yaml 876 inputs: 877 - data: 878 type: float32 # described bioimage.io model is compatible with any float32 input tensor 879 preprocessing: 880 - id: scale_range 881 kwargs: 882 axes: ['y', 'x'] 883 max_percentile: 99.8 884 min_percentile: 5.0 885 - id: clip 886 kwargs: 887 min: 0.0 888 max: 1.0 889 - id: ensure_dtype 890 kwargs: 891 dtype: uint8 892 ``` 893 - in Python: 894 >>> preprocessing = [ 895 ... ScaleRangeDescr( 896 ... kwargs=ScaleRangeKwargs( 897 ... axes= (AxisId('y'), AxisId('x')), 898 ... max_percentile= 99.8, 899 ... min_percentile= 5.0, 900 ... ) 901 ... ), 902 ... ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)), 903 ... EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")), 904 ... ] 905 """ 906 907 id: Literal["ensure_dtype"] = "ensure_dtype" 908 kwargs: EnsureDtypeKwargs
Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).
This can for example be used to ensure the inner neural network model gets a different input tensor data type than the fully described bioimage.io model does.
Examples:
The described bioimage.io model (incl. preprocessing) accepts any float32-compatible tensor, normalizes it with percentiles and clipping and then casts it to uint8, which is what the neural network in this example expects.
- in YAML
inputs:
- data:
    type: float32  # described bioimage.io model is compatible with any float32 input tensor
  preprocessing:
  - id: scale_range
    kwargs:
      axes: ['y', 'x']
      max_percentile: 99.8
      min_percentile: 5.0
  - id: clip
    kwargs:
      min: 0.0
      max: 1.0
  - id: ensure_dtype
    kwargs:
      dtype: uint8
- in Python:
>>> preprocessing = [
...     ScaleRangeDescr(
...         kwargs=ScaleRangeKwargs(
...             axes=(AxisId('y'), AxisId('x')),
...             max_percentile=99.8,
...             min_percentile=5.0,
...         )
...     ),
...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
... ]
911class ScaleLinearKwargs(ProcessingKwargs): 912 """Key word arguments for `ScaleLinearDescr`""" 913 914 gain: float = 1.0 915 """multiplicative factor""" 916 917 offset: float = 0.0 918 """additive term""" 919 920 @model_validator(mode="after") 921 def _validate(self) -> Self: 922 if self.gain == 1.0 and self.offset == 0.0: 923 raise ValueError( 924 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 925 + " != 0.0." 926 ) 927 928 return self
Keyword arguments for `ScaleLinearDescr`
931class ScaleLinearAlongAxisKwargs(ProcessingKwargs): 932 """Key word arguments for `ScaleLinearDescr`""" 933 934 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 935 """The axis of of gains/offsets values.""" 936 937 gain: Union[float, NotEmpty[List[float]]] = 1.0 938 """multiplicative factor""" 939 940 offset: Union[float, NotEmpty[List[float]]] = 0.0 941 """additive term""" 942 943 @model_validator(mode="after") 944 def _validate(self) -> Self: 945 946 if isinstance(self.gain, list): 947 if isinstance(self.offset, list): 948 if len(self.gain) != len(self.offset): 949 raise ValueError( 950 f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match." 951 ) 952 else: 953 self.offset = [float(self.offset)] * len(self.gain) 954 elif isinstance(self.offset, list): 955 self.gain = [float(self.gain)] * len(self.offset) 956 else: 957 raise ValueError( 958 "Do not specify an `axis` for scalar gain and offset values." 959 ) 960 961 if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset): 962 raise ValueError( 963 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 964 + " != 0.0." 965 ) 966 967 return self
Keyword arguments for `ScaleLinearDescr`
The axis of the gains/offsets values.
970class ScaleLinearDescr(ProcessingDescrBase): 971 """Fixed linear scaling. 972 973 Examples: 974 1. Scale with scalar gain and offset 975 - in YAML 976 ```yaml 977 preprocessing: 978 - id: scale_linear 979 kwargs: 980 gain: 2.0 981 offset: 3.0 982 ``` 983 - in Python: 984 >>> preprocessing = [ 985 ... ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0)) 986 ... ] 987 988 2. Independent scaling along an axis 989 - in YAML 990 ```yaml 991 preprocessing: 992 - id: scale_linear 993 kwargs: 994 axis: 'channel' 995 gain: [1.0, 2.0, 3.0] 996 ``` 997 - in Python: 998 >>> preprocessing = [ 999 ... ScaleLinearDescr( 1000 ... kwargs=ScaleLinearAlongAxisKwargs( 1001 ... axis=AxisId("channel"), 1002 ... gain=[1.0, 2.0, 3.0], 1003 ... ) 1004 ... ) 1005 ... ] 1006 1007 """ 1008 1009 id: Literal["scale_linear"] = "scale_linear" 1010 kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]
Fixed linear scaling.
Examples:
1. Scale with scalar gain and offset
   - in YAML
     preprocessing:
     - id: scale_linear
       kwargs:
         gain: 2.0
         offset: 3.0
   - in Python:
     >>> preprocessing = [
     ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain=2.0, offset=3.0))
     ... ]
2. Independent scaling along an axis
   - in YAML
     preprocessing:
     - id: scale_linear
       kwargs:
         axis: 'channel'
         gain: [1.0, 2.0, 3.0]
   - in Python:
     >>> preprocessing = [
     ...     ScaleLinearDescr(
     ...         kwargs=ScaleLinearAlongAxisKwargs(
     ...             axis=AxisId("channel"),
     ...             gain=[1.0, 2.0, 3.0],
     ...         )
     ...     )
     ... ]
1013class SigmoidDescr(ProcessingDescrBase): 1014 """The logistic sigmoid funciton, a.k.a. expit function. 1015 1016 Examples: 1017 - in YAML 1018 ```yaml 1019 postprocessing: 1020 - id: sigmoid 1021 ``` 1022 - in Python: 1023 >>> postprocessing = [SigmoidDescr()] 1024 """ 1025 1026 id: Literal["sigmoid"] = "sigmoid" 1027 1028 @property 1029 def kwargs(self) -> ProcessingKwargs: 1030 """empty kwargs""" 1031 return ProcessingKwargs()
The logistic sigmoid function, a.k.a. expit function.
Examples:
- in YAML
postprocessing:
- id: sigmoid
- in Python:
>>> postprocessing = [SigmoidDescr()]
1034class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs): 1035 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 1036 1037 mean: float 1038 """The mean value to normalize with.""" 1039 1040 std: Annotated[float, Ge(1e-6)] 1041 """The standard deviation value to normalize with."""
keyword arguments for `FixedZeroMeanUnitVarianceDescr`
1044class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs): 1045 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 1046 1047 mean: NotEmpty[List[float]] 1048 """The mean value(s) to normalize with.""" 1049 1050 std: NotEmpty[List[Annotated[float, Ge(1e-6)]]] 1051 """The standard deviation value(s) to normalize with. 1052 Size must match `mean` values.""" 1053 1054 axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])] 1055 """The axis of the mean/std values to normalize each entry along that dimension 1056 separately.""" 1057 1058 @model_validator(mode="after") 1059 def _mean_and_std_match(self) -> Self: 1060 if len(self.mean) != len(self.std): 1061 raise ValueError( 1062 f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})" 1063 + " must match." 1064 ) 1065 1066 return self
keyword arguments for `FixedZeroMeanUnitVarianceDescr`
The standard deviation value(s) to normalize with.
Size must match `mean` values.
The axis of the mean/std values to normalize each entry along that dimension separately.
1069class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase): 1070 """Subtract a given mean and divide by the standard deviation. 1071 1072 Normalize with fixed, precomputed values for 1073 `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std` 1074 Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given 1075 axes. 1076 1077 Examples: 1078 1. scalar value for whole tensor 1079 - in YAML 1080 ```yaml 1081 preprocessing: 1082 - id: fixed_zero_mean_unit_variance 1083 kwargs: 1084 mean: 103.5 1085 std: 13.7 1086 ``` 1087 - in Python 1088 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 1089 ... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7) 1090 ... )] 1091 1092 2. independently along an axis 1093 - in YAML 1094 ```yaml 1095 preprocessing: 1096 - id: fixed_zero_mean_unit_variance 1097 kwargs: 1098 axis: channel 1099 mean: [101.5, 102.5, 103.5] 1100 std: [11.7, 12.7, 13.7] 1101 ``` 1102 - in Python 1103 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 1104 ... kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 1105 ... axis=AxisId("channel"), 1106 ... mean=[101.5, 102.5, 103.5], 1107 ... std=[11.7, 12.7, 13.7], 1108 ... ) 1109 ... )] 1110 """ 1111 1112 id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance" 1113 kwargs: Union[ 1114 FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs 1115 ]
Subtract a given mean and divide by the standard deviation.
Normalize with fixed, precomputed values for `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`.
Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given axes.
Examples:
1. scalar value for whole tensor
- in YAML
preprocessing:
- id: fixed_zero_mean_unit_variance
kwargs:
mean: 103.5
std: 13.7
- in Python
>>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
... )]
2. independently along an axis
   - in YAML
     preprocessing:
     - id: fixed_zero_mean_unit_variance
       kwargs:
         axis: channel
         mean: [101.5, 102.5, 103.5]
         std: [11.7, 12.7, 13.7]
   - in Python
     >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
     ...     kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
     ...         axis=AxisId("channel"),
     ...         mean=[101.5, 102.5, 103.5],
     ...         std=[11.7, 12.7, 13.7],
     ...     )
     ... )]
1118class ZeroMeanUnitVarianceKwargs(ProcessingKwargs): 1119 """key word arguments for `ZeroMeanUnitVarianceDescr`""" 1120 1121 axes: Annotated[ 1122 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1123 ] = None 1124 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 1125 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1126 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1127 To normalize each sample independently leave out the 'batch' axis. 1128 Default: Scale all axes jointly.""" 1129 1130 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1131 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""
keyword arguments for `ZeroMeanUnitVarianceDescr`
The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
To normalize each sample independently leave out the 'batch' axis.
Default: Scale all axes jointly.
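A minimal sketch of normalizing each sample and channel independently by reducing over the spatial axes only (axis ids are made up):

from bioimageio.spec.model.v0_5 import (
    AxisId,
    ZeroMeanUnitVarianceDescr,
    ZeroMeanUnitVarianceKwargs,
)

preprocessing = [
    ZeroMeanUnitVarianceDescr(
        kwargs=ZeroMeanUnitVarianceKwargs(axes=(AxisId("x"), AxisId("y")))
    )
]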
1134class ZeroMeanUnitVarianceDescr(ProcessingDescrBase): 1135 """Subtract mean and divide by variance. 1136 1137 Examples: 1138 Subtract tensor mean and variance 1139 - in YAML 1140 ```yaml 1141 preprocessing: 1142 - id: zero_mean_unit_variance 1143 ``` 1144 - in Python 1145 >>> preprocessing = [ZeroMeanUnitVarianceDescr()] 1146 """ 1147 1148 id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance" 1149 kwargs: ZeroMeanUnitVarianceKwargs = Field( 1150 default_factory=ZeroMeanUnitVarianceKwargs 1151 )
Subtract mean and divide by the standard deviation.
Examples:
Subtract tensor mean and variance
- in YAML
preprocessing:
- id: zero_mean_unit_variance
- in Python
>>> preprocessing = [ZeroMeanUnitVarianceDescr()]
1154class ScaleRangeKwargs(ProcessingKwargs): 1155 """key word arguments for `ScaleRangeDescr` 1156 1157 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default) 1158 this processing step normalizes data to the [0, 1] intervall. 1159 For other percentiles the normalized values will partially be outside the [0, 1] 1160 intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the 1161 normalized values to a range. 1162 """ 1163 1164 axes: Annotated[ 1165 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1166 ] = None 1167 """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value. 1168 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1169 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1170 To normalize samples independently, leave out the "batch" axis. 1171 Default: Scale all axes jointly.""" 1172 1173 min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0 1174 """The lower percentile used to determine the value to align with zero.""" 1175 1176 max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0 1177 """The upper percentile used to determine the value to align with one. 1178 Has to be bigger than `min_percentile`. 1179 The range is 1 to 100 instead of 0 to 100 to avoid mistakenly 1180 accepting percentiles specified in the range 0.0 to 1.0.""" 1181 1182 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1183 """Epsilon for numeric stability. 1184 `out = (tensor - v_lower) / (v_upper - v_lower + eps)`; 1185 with `v_lower,v_upper` values at the respective percentiles.""" 1186 1187 reference_tensor: Optional[TensorId] = None 1188 """Tensor ID to compute the percentiles from. Default: The tensor itself. 1189 For any tensor in `inputs` only input tensor references are allowed.""" 1190 1191 @field_validator("max_percentile", mode="after") 1192 @classmethod 1193 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float: 1194 if (min_p := info.data["min_percentile"]) >= value: 1195 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}") 1196 1197 return value
keyword arguments for `ScaleRangeDescr`
For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
this processing step normalizes data to the [0, 1] interval.
For other percentiles the normalized values will partially be outside the [0, 1]
interval. Use `ScaleRangeDescr` followed by `ClipDescr` if you want to limit the
normalized values to a range.
The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
To normalize samples independently, leave out the "batch" axis.
Default: Scale all axes jointly.
The lower percentile used to determine the value to align with zero.
The upper percentile used to determine the value to align with one.
Has to be bigger than `min_percentile`.
The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
accepting percentiles specified in the range 0.0 to 1.0.
Epsilon for numeric stability.
`out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
with `v_lower,v_upper` values at the respective percentiles.
Tensor ID to compute the percentiles from. Default: The tensor itself.
For any tensor in `inputs` only input tensor references are allowed.
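For illustration, percentile normalization computed per sample and channel, taking the percentiles from a hypothetical reference tensor 'raw' (a minimal sketch):

from bioimageio.spec.model.v0_5 import (
    AxisId,
    ScaleRangeDescr,
    ScaleRangeKwargs,
    TensorId,
)

preprocessing = [
    ScaleRangeDescr(
        kwargs=ScaleRangeKwargs(
            axes=(AxisId("y"), AxisId("x")),   # reduce over the spatial axes only
            min_percentile=1.0,
            max_percentile=99.0,
            reference_tensor=TensorId("raw"),  # hypothetical tensor id
        )
    )
]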
1191 @field_validator("max_percentile", mode="after") 1192 @classmethod 1193 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float: 1194 if (min_p := info.data["min_percentile"]) >= value: 1195 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}") 1196 1197 return value
1200class ScaleRangeDescr(ProcessingDescrBase): 1201 """Scale with percentiles. 1202 1203 Examples: 1204 1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0 1205 - in YAML 1206 ```yaml 1207 preprocessing: 1208 - id: scale_range 1209 kwargs: 1210 axes: ['y', 'x'] 1211 max_percentile: 99.8 1212 min_percentile: 5.0 1213 ``` 1214 - in Python 1215 >>> preprocessing = [ 1216 ... ScaleRangeDescr( 1217 ... kwargs=ScaleRangeKwargs( 1218 ... axes= (AxisId('y'), AxisId('x')), 1219 ... max_percentile= 99.8, 1220 ... min_percentile= 5.0, 1221 ... ) 1222 ... ), 1223 ... ClipDescr( 1224 ... kwargs=ClipKwargs( 1225 ... min=0.0, 1226 ... max=1.0, 1227 ... ) 1228 ... ), 1229 ... ] 1230 1231 2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles. 1232 - in YAML 1233 ```yaml 1234 preprocessing: 1235 - id: scale_range 1236 kwargs: 1237 axes: ['y', 'x'] 1238 max_percentile: 99.8 1239 min_percentile: 5.0 1240 - id: scale_range 1241 - id: clip 1242 kwargs: 1243 min: 0.0 1244 max: 1.0 1245 ``` 1246 - in Python 1247 >>> preprocessing = [ScaleRangeDescr( 1248 ... kwargs=ScaleRangeKwargs( 1249 ... axes= (AxisId('y'), AxisId('x')), 1250 ... max_percentile= 99.8, 1251 ... min_percentile= 5.0, 1252 ... ) 1253 ... )] 1254 1255 """ 1256 1257 id: Literal["scale_range"] = "scale_range" 1258 kwargs: ScaleRangeKwargs
Scale with percentiles.
Examples:
1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
- in YAML
preprocessing:
- id: scale_range
kwargs:
axes: ['y', 'x']
max_percentile: 99.8
min_percentile: 5.0
- in Python
>>> preprocessing = [
... ScaleRangeDescr(
... kwargs=ScaleRangeKwargs(
... axes= (AxisId('y'), AxisId('x')),
... max_percentile= 99.8,
... min_percentile= 5.0,
... )
... ),
... ClipDescr(
... kwargs=ClipKwargs(
... min=0.0,
... max=1.0,
... )
... ),
... ]
2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
   - in YAML
     preprocessing:
     - id: scale_range
       kwargs:
         axes: ['y', 'x']
         max_percentile: 99.8
         min_percentile: 5.0
     - id: clip
       kwargs:
         min: 0.0
         max: 1.0
   - in Python
     >>> preprocessing = [ScaleRangeDescr(
     ...     kwargs=ScaleRangeKwargs(
     ...         axes=(AxisId('y'), AxisId('x')),
     ...         max_percentile=99.8,
     ...         min_percentile=5.0,
     ...     )
     ... )]
1261class ScaleMeanVarianceKwargs(ProcessingKwargs): 1262 """key word arguments for `ScaleMeanVarianceKwargs`""" 1263 1264 reference_tensor: TensorId 1265 """Name of tensor to match.""" 1266 1267 axes: Annotated[ 1268 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1269 ] = None 1270 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 1271 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1272 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1273 To normalize samples independently, leave out the 'batch' axis. 1274 Default: Scale all axes jointly.""" 1275 1276 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1277 """Epsilon for numeric stability: 1278 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""
keyword arguments for `ScaleMeanVarianceDescr`
The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
To normalize samples independently, leave out the 'batch' axis.
Default: Scale all axes jointly.
1281class ScaleMeanVarianceDescr(ProcessingDescrBase): 1282 """Scale a tensor's data distribution to match another tensor's mean/std. 1283 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.` 1284 """ 1285 1286 id: Literal["scale_mean_variance"] = "scale_mean_variance" 1287 kwargs: ScaleMeanVarianceKwargs
Scale a tensor's data distribution to match another tensor's mean/std.
`out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean`
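A minimal sketch matching an output tensor's distribution to that of a hypothetical input tensor 'raw':

from bioimageio.spec.model.v0_5 import (
    AxisId,
    ScaleMeanVarianceDescr,
    ScaleMeanVarianceKwargs,
    TensorId,
)

postprocessing = [
    ScaleMeanVarianceDescr(
        kwargs=ScaleMeanVarianceKwargs(
            reference_tensor=TensorId("raw"),  # hypothetical tensor id
            axes=(AxisId("y"), AxisId("x")),   # compute mean/std per sample and channel
        )
    )
]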
1321class TensorDescrBase(Node, Generic[IO_AxisT]): 1322 id: TensorId 1323 """Tensor id. No duplicates are allowed.""" 1324 1325 description: Annotated[str, MaxLen(128)] = "" 1326 """free text description""" 1327 1328 axes: NotEmpty[Sequence[IO_AxisT]] 1329 """tensor axes""" 1330 1331 @property 1332 def shape(self): 1333 return tuple(a.size for a in self.axes) 1334 1335 @field_validator("axes", mode="after", check_fields=False) 1336 @classmethod 1337 def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]: 1338 batch_axes = [a for a in axes if a.type == "batch"] 1339 if len(batch_axes) > 1: 1340 raise ValueError( 1341 f"Only one batch axis (per tensor) allowed, but got {batch_axes}" 1342 ) 1343 1344 seen_ids: Set[AxisId] = set() 1345 duplicate_axes_ids: Set[AxisId] = set() 1346 for a in axes: 1347 (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id) 1348 1349 if duplicate_axes_ids: 1350 raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}") 1351 1352 return axes 1353 1354 test_tensor: FileDescr 1355 """An example tensor to use for testing. 1356 Using the model with the test input tensors is expected to yield the test output tensors. 1357 Each test tensor has be a an ndarray in the 1358 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format). 1359 The file extension must be '.npy'.""" 1360 1361 sample_tensor: Optional[FileDescr] = None 1362 """A sample tensor to illustrate a possible input/output for the model, 1363 The sample image primarily serves to inform a human user about an example use case 1364 and is typically stored as .hdf5, .png or .tiff. 1365 It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats) 1366 (numpy's `.npy` format is not supported). 1367 The image dimensionality has to match the number of axes specified in this tensor description. 1368 """ 1369 1370 @model_validator(mode="after") 1371 def _validate_sample_tensor(self) -> Self: 1372 if ( 1373 self.sample_tensor is None 1374 or not validation_context_var.get().perform_io_checks 1375 ): 1376 return self 1377 1378 local = download(self.sample_tensor.source, sha256=self.sample_tensor.sha256) 1379 tensor: NDArray[Any] = imread( 1380 local.path.read_bytes(), 1381 extension=PurePosixPath(local.original_file_name).suffix, 1382 ) 1383 n_dims = len(tensor.squeeze().shape) 1384 n_dims_min = n_dims_max = len(self.axes) 1385 1386 for a in self.axes: 1387 if isinstance(a, BatchAxis): 1388 n_dims_min -= 1 1389 elif isinstance(a.size, int): 1390 if a.size == 1: 1391 n_dims_min -= 1 1392 elif isinstance(a.size, (ParameterizedSize, DataDependentSize)): 1393 if a.size.min == 1: 1394 n_dims_min -= 1 1395 elif isinstance(a.size, SizeReference): 1396 if a.size.offset < 2: 1397 # size reference may result in singleton axis 1398 n_dims_min -= 1 1399 else: 1400 assert_never(a.size) 1401 1402 n_dims_min = max(0, n_dims_min) 1403 if n_dims < n_dims_min or n_dims > n_dims_max: 1404 raise ValueError( 1405 f"Expected sample tensor to have {n_dims_min} to" 1406 + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})." 1407 ) 1408 1409 return self 1410 1411 data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = ( 1412 IntervalOrRatioDataDescr() 1413 ) 1414 """Description of the tensor's data values, optionally per channel. 
1415 If specified per channel, the data `type` needs to match across channels.""" 1416 1417 @property 1418 def dtype( 1419 self, 1420 ) -> Literal[ 1421 "float32", 1422 "float64", 1423 "uint8", 1424 "int8", 1425 "uint16", 1426 "int16", 1427 "uint32", 1428 "int32", 1429 "uint64", 1430 "int64", 1431 "bool", 1432 ]: 1433 """dtype as specified under `data.type` or `data[i].type`""" 1434 if isinstance(self.data, collections.abc.Sequence): 1435 return self.data[0].type 1436 else: 1437 return self.data.type 1438 1439 @field_validator("data", mode="after") 1440 @classmethod 1441 def _check_data_type_across_channels( 1442 cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] 1443 ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]: 1444 if not isinstance(value, list): 1445 return value 1446 1447 dtypes = {t.type for t in value} 1448 if len(dtypes) > 1: 1449 raise ValueError( 1450 "Tensor data descriptions per channel need to agree in their data" 1451 + f" `type`, but found {dtypes}." 1452 ) 1453 1454 return value 1455 1456 @model_validator(mode="after") 1457 def _check_data_matches_channelaxis(self) -> Self: 1458 if not isinstance(self.data, (list, tuple)): 1459 return self 1460 1461 for a in self.axes: 1462 if isinstance(a, ChannelAxis): 1463 size = a.size 1464 assert isinstance(size, int) 1465 break 1466 else: 1467 return self 1468 1469 if len(self.data) != size: 1470 raise ValueError( 1471 f"Got tensor data descriptions for {len(self.data)} channels, but" 1472 + f" '{a.id}' axis has size {size}." 1473 ) 1474 1475 return self 1476 1477 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]: 1478 if len(array.shape) != len(self.axes): 1479 raise ValueError( 1480 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})" 1481 + f" incompatible with {len(self.axes)} axes." 1482 ) 1483 return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
An example tensor to use for testing. Using the model with the test input tensors is expected to yield the test output tensors. Each test tensor has to be an ndarray in the numpy.lib file format. The file extension must be '.npy'.
A sample tensor to illustrate a possible input/output for the model.
The sample image primarily serves to inform a human user about an example use case
and is typically stored as .hdf5, .png or .tiff.
It has to be readable by the imageio library
(numpy's `.npy` format is not supported).
The image dimensionality has to match the number of axes specified in this tensor description.
Description of the tensor's data values, optionally per channel.
If specified per channel, the data `type` needs to match across channels.
1417 @property 1418 def dtype( 1419 self, 1420 ) -> Literal[ 1421 "float32", 1422 "float64", 1423 "uint8", 1424 "int8", 1425 "uint16", 1426 "int16", 1427 "uint32", 1428 "int32", 1429 "uint64", 1430 "int64", 1431 "bool", 1432 ]: 1433 """dtype as specified under `data.type` or `data[i].type`""" 1434 if isinstance(self.data, collections.abc.Sequence): 1435 return self.data[0].type 1436 else: 1437 return self.data.type
dtype as specified under `data.type` or `data[i].type`
1477 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]: 1478 if len(array.shape) != len(self.axes): 1479 raise ValueError( 1480 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})" 1481 + f" incompatible with {len(self.axes)} axes." 1482 ) 1483 return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
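For illustration, a sketch of mapping an array's shape to axis ids; `descr` is assumed to be an already constructed tensor description with axes ('batch', 'channel', 'y', 'x'):

import numpy as np

# `descr` is assumed to be an existing InputTensorDescr/OutputTensorDescr instance
array = np.zeros((1, 3, 256, 256), dtype="float32")
sizes = descr.get_axis_sizes_for_array(array)
# e.g. {AxisId('batch'): 1, AxisId('channel'): 3, AxisId('y'): 256, AxisId('x'): 256}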
1486class InputTensorDescr(TensorDescrBase[InputAxis]): 1487 id: TensorId = TensorId("input") 1488 """Input tensor id. 1489 No duplicates are allowed across all inputs and outputs.""" 1490 1491 optional: bool = False 1492 """indicates that this tensor may be `None`""" 1493 1494 preprocessing: List[PreprocessingDescr] = Field(default_factory=list) 1495 """Description of how this input should be preprocessed. 1496 1497 notes: 1498 - If preprocessing does not start with an 'ensure_dtype' entry, it is added 1499 to ensure an input tensor's data type matches the input tensor's data description. 1500 - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 1501 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally 1502 changing the data type. 1503 """ 1504 1505 @model_validator(mode="after") 1506 def _validate_preprocessing_kwargs(self) -> Self: 1507 axes_ids = [a.id for a in self.axes] 1508 for p in self.preprocessing: 1509 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 1510 if kwargs_axes is None: 1511 continue 1512 1513 if not isinstance(kwargs_axes, collections.abc.Sequence): 1514 raise ValueError( 1515 f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}" 1516 ) 1517 1518 if any(a not in axes_ids for a in kwargs_axes): 1519 raise ValueError( 1520 "`preprocessing.i.kwargs.axes` needs to be subset of axes ids" 1521 ) 1522 1523 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 1524 dtype = self.data.type 1525 else: 1526 dtype = self.data[0].type 1527 1528 # ensure `preprocessing` begins with `EnsureDtypeDescr` 1529 if not self.preprocessing or not isinstance( 1530 self.preprocessing[0], EnsureDtypeDescr 1531 ): 1532 self.preprocessing.insert( 1533 0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1534 ) 1535 1536 # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 1537 if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)): 1538 self.preprocessing.append( 1539 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1540 ) 1541 1542 return self
Description of how this input should be preprocessed.
notes:
- If preprocessing does not start with an 'ensure_dtype' entry, it is added to ensure an input tensor's data type matches the input tensor's data description.
- If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally changing the data type.
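To illustrate the effect of these notes, the following sketch shows what a preprocessing chain looks like once the implicit 'ensure_dtype' steps are in place (the dtype 'float32' is a made-up example matching the tensor's data description):

from bioimageio.spec.model.v0_5 import (
    AxisId,
    EnsureDtypeDescr,
    EnsureDtypeKwargs,
    ScaleRangeDescr,
    ScaleRangeKwargs,
)

preprocessing = [
    EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32")),  # inserted at the start
    ScaleRangeDescr(
        kwargs=ScaleRangeKwargs(
            axes=(AxisId("y"), AxisId("x")),
            min_percentile=5.0,
            max_percentile=99.8,
        )
    ),
    EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32")),  # appended at the end
]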
1545def convert_axes( 1546 axes: str, 1547 *, 1548 shape: Union[ 1549 Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4 1550 ], 1551 tensor_type: Literal["input", "output"], 1552 halo: Optional[Sequence[int]], 1553 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 1554): 1555 ret: List[AnyAxis] = [] 1556 for i, a in enumerate(axes): 1557 axis_type = _AXIS_TYPE_MAP.get(a, a) 1558 if axis_type == "batch": 1559 ret.append(BatchAxis()) 1560 continue 1561 1562 scale = 1.0 1563 if isinstance(shape, _ParameterizedInputShape_v0_4): 1564 if shape.step[i] == 0: 1565 size = shape.min[i] 1566 else: 1567 size = ParameterizedSize(min=shape.min[i], step=shape.step[i]) 1568 elif isinstance(shape, _ImplicitOutputShape_v0_4): 1569 ref_t = str(shape.reference_tensor) 1570 if ref_t.count(".") == 1: 1571 t_id, orig_a_id = ref_t.split(".") 1572 else: 1573 t_id = ref_t 1574 orig_a_id = a 1575 1576 a_id = _AXIS_ID_MAP.get(orig_a_id, a) 1577 if not (orig_scale := shape.scale[i]): 1578 # old way to insert a new axis dimension 1579 size = int(2 * shape.offset[i]) 1580 else: 1581 scale = 1 / orig_scale 1582 if axis_type in ("channel", "index"): 1583 # these axes no longer have a scale 1584 offset_from_scale = orig_scale * size_refs.get( 1585 _TensorName_v0_4(t_id), {} 1586 ).get(orig_a_id, 0) 1587 else: 1588 offset_from_scale = 0 1589 size = SizeReference( 1590 tensor_id=TensorId(t_id), 1591 axis_id=AxisId(a_id), 1592 offset=int(offset_from_scale + 2 * shape.offset[i]), 1593 ) 1594 else: 1595 size = shape[i] 1596 1597 if axis_type == "time": 1598 if tensor_type == "input": 1599 ret.append(TimeInputAxis(size=size, scale=scale)) 1600 else: 1601 assert not isinstance(size, ParameterizedSize) 1602 if halo is None: 1603 ret.append(TimeOutputAxis(size=size, scale=scale)) 1604 else: 1605 assert not isinstance(size, int) 1606 ret.append( 1607 TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i]) 1608 ) 1609 1610 elif axis_type == "index": 1611 if tensor_type == "input": 1612 ret.append(IndexInputAxis(size=size)) 1613 else: 1614 if isinstance(size, ParameterizedSize): 1615 size = DataDependentSize(min=size.min) 1616 1617 ret.append(IndexOutputAxis(size=size)) 1618 elif axis_type == "channel": 1619 assert not isinstance(size, ParameterizedSize) 1620 if isinstance(size, SizeReference): 1621 warnings.warn( 1622 "Conversion of channel size from an implicit output shape may be" 1623 + " wrong" 1624 ) 1625 ret.append( 1626 ChannelAxis( 1627 channel_names=[ 1628 Identifier(f"channel{i}") for i in range(size.offset) 1629 ] 1630 ) 1631 ) 1632 else: 1633 ret.append( 1634 ChannelAxis( 1635 channel_names=[Identifier(f"channel{i}") for i in range(size)] 1636 ) 1637 ) 1638 elif axis_type == "space": 1639 if tensor_type == "input": 1640 ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale)) 1641 else: 1642 assert not isinstance(size, ParameterizedSize) 1643 if halo is None or halo[i] == 0: 1644 ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale)) 1645 elif isinstance(size, int): 1646 raise NotImplementedError( 1647 f"output axis with halo and fixed size (here {size}) not allowed" 1648 ) 1649 else: 1650 ret.append( 1651 SpaceOutputAxisWithHalo( 1652 id=AxisId(a), size=size, scale=scale, halo=halo[i] 1653 ) 1654 ) 1655 1656 return ret
1831class OutputTensorDescr(TensorDescrBase[OutputAxis]): 1832 id: TensorId = TensorId("output") 1833 """Output tensor id. 1834 No duplicates are allowed across all inputs and outputs.""" 1835 1836 postprocessing: List[PostprocessingDescr] = Field(default_factory=list) 1837 """Description of how this output should be postprocessed. 1838 1839 note: `postprocessing` always ends with an 'ensure_dtype' operation. 1840 If not given this is added to cast to this tensor's `data.type`. 1841 """ 1842 1843 @model_validator(mode="after") 1844 def _validate_postprocessing_kwargs(self) -> Self: 1845 axes_ids = [a.id for a in self.axes] 1846 for p in self.postprocessing: 1847 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 1848 if kwargs_axes is None: 1849 continue 1850 1851 if not isinstance(kwargs_axes, collections.abc.Sequence): 1852 raise ValueError( 1853 f"expected `axes` sequence, but got {type(kwargs_axes)}" 1854 ) 1855 1856 if any(a not in axes_ids for a in kwargs_axes): 1857 raise ValueError("`kwargs.axes` needs to be subset of axes ids") 1858 1859 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 1860 dtype = self.data.type 1861 else: 1862 dtype = self.data[0].type 1863 1864 # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 1865 if not self.postprocessing or not isinstance( 1866 self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr) 1867 ): 1868 self.postprocessing.append( 1869 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1870 ) 1871 return self
Description of how this output should be postprocessed.
note: `postprocessing` always ends with an 'ensure_dtype' operation.
If not given this is added to cast to this tensor's `data.type`.
1921def validate_tensors( 1922 tensors: Mapping[TensorId, Tuple[TensorDescr, NDArray[Any]]], 1923 tensor_origin: str, # for more precise error messages, e.g. 'test_tensor' 1924): 1925 all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, int]]] = {} 1926 1927 def e_msg(d: TensorDescr): 1928 return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]" 1929 1930 for descr, array in tensors.values(): 1931 try: 1932 axis_sizes = descr.get_axis_sizes_for_array(array) 1933 except ValueError as e: 1934 raise ValueError(f"{e_msg(descr)} {e}") 1935 else: 1936 all_tensor_axes[descr.id] = { 1937 a.id: (a, axis_sizes[a.id]) for a in descr.axes 1938 } 1939 1940 for descr, array in tensors.values(): 1941 if array.dtype.name != descr.dtype: 1942 raise ValueError( 1943 f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not" 1944 + f" match described dtype '{descr.dtype}'" 1945 ) 1946 1947 for a in descr.axes: 1948 actual_size = all_tensor_axes[descr.id][a.id][1] 1949 if a.size is None: 1950 continue 1951 1952 if isinstance(a.size, int): 1953 if actual_size != a.size: 1954 raise ValueError( 1955 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' " 1956 + f"has incompatible size {actual_size}, expected {a.size}" 1957 ) 1958 elif isinstance(a.size, ParameterizedSize): 1959 _ = a.size.validate_size(actual_size) 1960 elif isinstance(a.size, DataDependentSize): 1961 _ = a.size.validate_size(actual_size) 1962 elif isinstance(a.size, SizeReference): 1963 ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id) 1964 if ref_tensor_axes is None: 1965 raise ValueError( 1966 f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor" 1967 + f" reference '{a.size.tensor_id}'" 1968 ) 1969 1970 ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None)) 1971 if ref_axis is None or ref_size is None: 1972 raise ValueError( 1973 f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis" 1974 + f" reference '{a.size.tensor_id}.{a.size.axis_id}" 1975 ) 1976 1977 if a.unit != ref_axis.unit: 1978 raise ValueError( 1979 f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires" 1980 + " axis and reference axis to have the same `unit`, but" 1981 + f" {a.unit}!={ref_axis.unit}" 1982 ) 1983 1984 if actual_size != ( 1985 expected_size := ( 1986 ref_size * ref_axis.scale / a.scale + a.size.offset 1987 ) 1988 ): 1989 raise ValueError( 1990 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size" 1991 + f" {actual_size} invalid for referenced size {ref_size};" 1992 + f" expected {expected_size}" 1993 ) 1994 else: 1995 assert_never(a.size)
1998class EnvironmentFileDescr(FileDescr): 1999 source: Annotated[ 2000 ImportantFileSource, 2001 WithSuffix((".yaml", ".yml"), case_sensitive=True), 2002 Field( 2003 examples=["environment.yaml"], 2004 ), 2005 ] 2006 """∈📦 Conda environment file. 2007 Allows to specify custom dependencies, see conda docs: 2008 - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms) 2009 - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually) 2010 """
∈📦 Conda environment file. Allows to specify custom dependencies, see conda docs:
- Exporting an environment file across platforms: https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms
- Creating an environment file manually: https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually
2093class WeightsEntryDescrBase(FileDescr): 2094 type: ClassVar[WeightsFormat] 2095 weights_format_name: ClassVar[str] # human readable 2096 2097 source: ImportantFileSource 2098 """∈📦 The weights file.""" 2099 2100 authors: Optional[List[Author]] = None 2101 """Authors 2102 Either the person(s) that have trained this model resulting in the original weights file. 2103 (If this is the initial weights entry, i.e. it does not have a `parent`) 2104 Or the person(s) who have converted the weights to this weights format. 2105 (If this is a child weight, i.e. it has a `parent` field) 2106 """ 2107 2108 parent: Annotated[ 2109 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"]) 2110 ] = None 2111 """The source weights these weights were converted from. 2112 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`, 2113 The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights. 2114 All weight entries except one (the initial set of weights resulting from training the model), 2115 need to have this field.""" 2116 2117 @model_validator(mode="after") 2118 def check_parent_is_not_self(self) -> Self: 2119 if self.type == self.parent: 2120 raise ValueError("Weights entry can't be it's own parent.") 2121 2122 return self
∈📦 The weights file.
The source weights these weights were converted from.
For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
the `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
All weight entries except one (the initial set of weights resulting from training the model)
need to have this field.
2125class KerasHdf5WeightsDescr(WeightsEntryDescrBase): 2126 type = "keras_hdf5" 2127 weights_format_name: ClassVar[str] = "Keras HDF5" 2128 tensorflow_version: Version 2129 """TensorFlow version used to create these weights."""
TensorFlow version used to create these weights.
2139class PytorchStateDictWeightsDescr(WeightsEntryDescrBase): 2140 type = "pytorch_state_dict" 2141 weights_format_name: ClassVar[str] = "Pytorch State Dict" 2142 architecture: ArchitectureDescr 2143 pytorch_version: Version 2144 """Version of the PyTorch library used. 2145 If `architecture.depencencies` is specified it has to include pytorch and any version pinning has to be compatible. 2146 """ 2147 dependencies: Optional[EnvironmentFileDescr] = None 2148 """Custom depencies beyond pytorch. 2149 The conda environment file should include pytorch and any version pinning has to be compatible with 2150 `pytorch_version`. 2151 """
Version of the PyTorch library used.
If `architecture.dependencies` is specified it has to include pytorch and any version pinning has to be compatible.
Custom dependencies beyond pytorch.
The conda environment file should include pytorch and any version pinning has to be compatible with `pytorch_version`.
2154class TensorflowJsWeightsDescr(WeightsEntryDescrBase): 2155 type = "tensorflow_js" 2156 weights_format_name: ClassVar[str] = "Tensorflow.js" 2157 tensorflow_version: Version 2158 """Version of the TensorFlow library used.""" 2159 2160 source: ImportantFileSource 2161 """∈📦 The multi-file weights. 2162 All required files/folders should be a zip archive."""
Version of the TensorFlow library used.
∈📦 The multi-file weights. All required files/folders should be included in a zip archive.
2165class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase): 2166 type = "tensorflow_saved_model_bundle" 2167 weights_format_name: ClassVar[str] = "Tensorflow Saved Model" 2168 tensorflow_version: Version 2169 """Version of the TensorFlow library used.""" 2170 2171 dependencies: Optional[EnvironmentFileDescr] = None 2172 """Custom dependencies beyond tensorflow. 2173 Should include tensorflow and any version pinning has to be compatible with `tensorflow_version`.""" 2174 2175 source: ImportantFileSource 2176 """∈📦 The multi-file weights. 2177 All required files/folders should be a zip archive."""
Version of the TensorFlow library used.
Custom dependencies beyond tensorflow.
Should include tensorflow and any version pinning has to be compatible with `tensorflow_version`.
∈📦 The multi-file weights. All required files/folders should be included in a zip archive.
2180class TorchscriptWeightsDescr(WeightsEntryDescrBase): 2181 type = "torchscript" 2182 weights_format_name: ClassVar[str] = "TorchScript" 2183 pytorch_version: Version 2184 """Version of the PyTorch library used."""
Version of the PyTorch library used.
2187class WeightsDescr(Node): 2188 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None 2189 onnx: Optional[OnnxWeightsDescr] = None 2190 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None 2191 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None 2192 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = ( 2193 None 2194 ) 2195 torchscript: Optional[TorchscriptWeightsDescr] = None 2196 2197 @model_validator(mode="after") 2198 def check_entries(self) -> Self: 2199 entries = {wtype for wtype, entry in self if entry is not None} 2200 2201 if not entries: 2202 raise ValueError("Missing weights entry") 2203 2204 entries_wo_parent = { 2205 wtype 2206 for wtype, entry in self 2207 if entry is not None and hasattr(entry, "parent") and entry.parent is None 2208 } 2209 if len(entries_wo_parent) != 1: 2210 issue_warning( 2211 "Exactly one weights entry may not specify the `parent` field (got" 2212 + " {value}). That entry is considered the original set of model weights." 2213 + " Other weight formats are created through conversion of the orignal or" 2214 + " already converted weights. They have to reference the weights format" 2215 + " they were converted from as their `parent`.", 2216 value=len(entries_wo_parent), 2217 field="weights", 2218 ) 2219 2220 for wtype, entry in self: 2221 if entry is None: 2222 continue 2223 2224 assert hasattr(entry, "type") 2225 assert hasattr(entry, "parent") 2226 assert wtype == entry.type 2227 if ( 2228 entry.parent is not None and entry.parent not in entries 2229 ): # self reference checked for `parent` field 2230 raise ValueError( 2231 f"`weights.{wtype}.parent={entry.parent} not in specified weight" 2232 + f" formats: {entries}" 2233 ) 2234 2235 return self 2236 2237 def __getitem__( 2238 self, 2239 key: Literal[ 2240 "keras_hdf5", 2241 "onnx", 2242 "pytorch_state_dict", 2243 "tensorflow_js", 2244 "tensorflow_saved_model_bundle", 2245 "torchscript", 2246 ], 2247 ): 2248 if key == "keras_hdf5": 2249 ret = self.keras_hdf5 2250 elif key == "onnx": 2251 ret = self.onnx 2252 elif key == "pytorch_state_dict": 2253 ret = self.pytorch_state_dict 2254 elif key == "tensorflow_js": 2255 ret = self.tensorflow_js 2256 elif key == "tensorflow_saved_model_bundle": 2257 ret = self.tensorflow_saved_model_bundle 2258 elif key == "torchscript": 2259 ret = self.torchscript 2260 else: 2261 raise KeyError(key) 2262 2263 if ret is None: 2264 raise KeyError(key) 2265 2266 return ret 2267 2268 @property 2269 def available_formats(self): 2270 return { 2271 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 2272 **({} if self.onnx is None else {"onnx": self.onnx}), 2273 **( 2274 {} 2275 if self.pytorch_state_dict is None 2276 else {"pytorch_state_dict": self.pytorch_state_dict} 2277 ), 2278 **( 2279 {} 2280 if self.tensorflow_js is None 2281 else {"tensorflow_js": self.tensorflow_js} 2282 ), 2283 **( 2284 {} 2285 if self.tensorflow_saved_model_bundle is None 2286 else { 2287 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 2288 } 2289 ), 2290 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 2291 } 2292 2293 @property 2294 def missing_formats(self): 2295 return { 2296 wf for wf in get_args(WeightsFormat) if wf not in self.available_formats 2297 }
2197 @model_validator(mode="after") 2198 def check_entries(self) -> Self: 2199 entries = {wtype for wtype, entry in self if entry is not None} 2200 2201 if not entries: 2202 raise ValueError("Missing weights entry") 2203 2204 entries_wo_parent = { 2205 wtype 2206 for wtype, entry in self 2207 if entry is not None and hasattr(entry, "parent") and entry.parent is None 2208 } 2209 if len(entries_wo_parent) != 1: 2210 issue_warning( 2211 "Exactly one weights entry may not specify the `parent` field (got" 2212 + " {value}). That entry is considered the original set of model weights." 2213 + " Other weight formats are created through conversion of the orignal or" 2214 + " already converted weights. They have to reference the weights format" 2215 + " they were converted from as their `parent`.", 2216 value=len(entries_wo_parent), 2217 field="weights", 2218 ) 2219 2220 for wtype, entry in self: 2221 if entry is None: 2222 continue 2223 2224 assert hasattr(entry, "type") 2225 assert hasattr(entry, "parent") 2226 assert wtype == entry.type 2227 if ( 2228 entry.parent is not None and entry.parent not in entries 2229 ): # self reference checked for `parent` field 2230 raise ValueError( 2231 f"`weights.{wtype}.parent={entry.parent} not in specified weight" 2232 + f" formats: {entries}" 2233 ) 2234 2235 return self
2268 @property 2269 def available_formats(self): 2270 return { 2271 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 2272 **({} if self.onnx is None else {"onnx": self.onnx}), 2273 **( 2274 {} 2275 if self.pytorch_state_dict is None 2276 else {"pytorch_state_dict": self.pytorch_state_dict} 2277 ), 2278 **( 2279 {} 2280 if self.tensorflow_js is None 2281 else {"tensorflow_js": self.tensorflow_js} 2282 ), 2283 **( 2284 {} 2285 if self.tensorflow_saved_model_bundle is None 2286 else { 2287 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 2288 } 2289 ), 2290 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 2291 }
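As a usage sketch (not part of the spec itself): given a validated `WeightsDescr` instance, the accessors above can be used to inspect which weight formats are present and how they relate via `parent`. The helper function and the format chosen for item access below are illustrative only.

from bioimageio.spec.model.v0_5 import WeightsDescr


def describe_weights(weights: WeightsDescr) -> None:
    # `available_formats` maps format name -> weights entry for every entry that is set
    for fmt, entry in weights.available_formats.items():
        origin = "original" if entry.parent is None else f"converted from {entry.parent}"
        print(f"{fmt}: {origin}")

    # formats without an entry
    print("missing:", sorted(weights.missing_formats))

    # item access raises KeyError for absent formats
    try:
        print(weights["onnx"].parent)
    except KeyError:
        print("no onnx weights available")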
2304class LinkedModel(LinkedResourceBase): 2305 """Reference to a bioimage.io model.""" 2306 2307 id: ModelId 2308 """A valid model `id` from the bioimage.io collection."""
Reference to a bioimage.io model.
2330class ModelDescr(GenericModelDescrBase): 2331 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. 2332 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 2333 """ 2334 2335 format_version: Literal["0.5.3"] = "0.5.3" 2336 """Version of the bioimage.io model description specification used. 2337 When creating a new model always use the latest micro/patch version described here. 2338 The `format_version` is important for any consumer software to understand how to parse the fields. 2339 """ 2340 2341 type: Literal["model"] = "model" 2342 """Specialized resource type 'model'""" 2343 2344 id: Optional[ModelId] = None 2345 """bioimage.io-wide unique resource identifier 2346 assigned by bioimage.io; version **un**specific.""" 2347 2348 authors: NotEmpty[List[Author]] 2349 """The authors are the creators of the model RDF and the primary points of contact.""" 2350 2351 documentation: Annotated[ 2352 DocumentationSource, 2353 Field( 2354 examples=[ 2355 "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md", 2356 "README.md", 2357 ], 2358 ), 2359 ] 2360 """∈📦 URL or relative path to a markdown file with additional documentation. 2361 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. 2362 The documentation should include a '#[#] Validation' (sub)section 2363 with details on how to quantitatively validate the model on unseen data.""" 2364 2365 @field_validator("documentation", mode="after") 2366 @classmethod 2367 def _validate_documentation(cls, value: DocumentationSource) -> DocumentationSource: 2368 if not validation_context_var.get().perform_io_checks: 2369 return value 2370 2371 doc_path = download(value).path 2372 doc_content = doc_path.read_text(encoding="utf-8") 2373 assert isinstance(doc_content, str) 2374 if not re.match("#.*[vV]alidation", doc_content): 2375 issue_warning( 2376 "No '# Validation' (sub)section found in {value}.", 2377 value=value, 2378 field="documentation", 2379 ) 2380 2381 return value 2382 2383 inputs: NotEmpty[Sequence[InputTensorDescr]] 2384 """Describes the input tensors expected by this model.""" 2385 2386 @field_validator("inputs", mode="after") 2387 @classmethod 2388 def _validate_input_axes( 2389 cls, inputs: Sequence[InputTensorDescr] 2390 ) -> Sequence[InputTensorDescr]: 2391 input_size_refs = cls._get_axes_with_independent_size(inputs) 2392 2393 for i, ipt in enumerate(inputs): 2394 valid_independent_refs: Dict[ 2395 Tuple[TensorId, AxisId], 2396 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2397 ] = { 2398 **{ 2399 (ipt.id, a.id): (ipt, a, a.size) 2400 for a in ipt.axes 2401 if not isinstance(a, BatchAxis) 2402 and isinstance(a.size, (int, ParameterizedSize)) 2403 }, 2404 **input_size_refs, 2405 } 2406 for a, ax in enumerate(ipt.axes): 2407 cls._validate_axis( 2408 "inputs", 2409 i=i, 2410 tensor_id=ipt.id, 2411 a=a, 2412 axis=ax, 2413 valid_independent_refs=valid_independent_refs, 2414 ) 2415 return inputs 2416 2417 @staticmethod 2418 def _validate_axis( 2419 field_name: str, 2420 i: int, 2421 tensor_id: TensorId, 2422 a: int, 2423 axis: AnyAxis, 2424 valid_independent_refs: Dict[ 2425 Tuple[TensorId, AxisId], 2426 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2427 ], 2428 ): 2429 if isinstance(axis, BatchAxis) or isinstance( 2430 axis.size, (int, ParameterizedSize, DataDependentSize) 2431 ): 2432 return 
2433 elif not isinstance(axis.size, SizeReference): 2434 assert_never(axis.size) 2435 2436 # validate axis.size SizeReference 2437 ref = (axis.size.tensor_id, axis.size.axis_id) 2438 if ref not in valid_independent_refs: 2439 raise ValueError( 2440 "Invalid tensor axis reference at" 2441 + f" {field_name}[{i}].axes[{a}].size: {axis.size}." 2442 ) 2443 if ref == (tensor_id, axis.id): 2444 raise ValueError( 2445 "Self-referencing not allowed for" 2446 + f" {field_name}[{i}].axes[{a}].size: {axis.size}" 2447 ) 2448 if axis.type == "channel": 2449 if valid_independent_refs[ref][1].type != "channel": 2450 raise ValueError( 2451 "A channel axis' size may only reference another fixed size" 2452 + " channel axis." 2453 ) 2454 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names: 2455 ref_size = valid_independent_refs[ref][2] 2456 assert isinstance(ref_size, int), ( 2457 "channel axis ref (another channel axis) has to specify fixed" 2458 + " size" 2459 ) 2460 generated_channel_names = [ 2461 Identifier(axis.channel_names.format(i=i)) 2462 for i in range(1, ref_size + 1) 2463 ] 2464 axis.channel_names = generated_channel_names 2465 2466 if (ax_unit := getattr(axis, "unit", None)) != ( 2467 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None) 2468 ): 2469 raise ValueError( 2470 "The units of an axis and its reference axis need to match, but" 2471 + f" '{ax_unit}' != '{ref_unit}'." 2472 ) 2473 ref_axis = valid_independent_refs[ref][1] 2474 if isinstance(ref_axis, BatchAxis): 2475 raise ValueError( 2476 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}" 2477 + " (a batch axis is not allowed as reference)." 2478 ) 2479 2480 if isinstance(axis, WithHalo): 2481 min_size = axis.size.get_size(axis, ref_axis, n=0) 2482 if (min_size - 2 * axis.halo) < 1: 2483 raise ValueError( 2484 f"axis {axis.id} with minimum size {min_size} is too small for halo" 2485 + f" {axis.halo}." 2486 ) 2487 2488 input_halo = axis.halo * axis.scale / ref_axis.scale 2489 if input_halo != int(input_halo) or input_halo % 2 == 1: 2490 raise ValueError( 2491 f"input_halo {input_halo} (output_halo {axis.halo} *" 2492 + f" output_scale {axis.scale} / input_scale {ref_axis.scale})" 2493 + f" is not an even integer for {tensor_id}.{axis.id}." 2494 ) 2495 2496 @model_validator(mode="after") 2497 def _validate_test_tensors(self) -> Self: 2498 if not validation_context_var.get().perform_io_checks: 2499 return self 2500 2501 test_arrays = [ 2502 load_array(descr.test_tensor.download().path) 2503 for descr in chain(self.inputs, self.outputs) 2504 ] 2505 tensors = { 2506 descr.id: (descr, array) 2507 for descr, array in zip(chain(self.inputs, self.outputs), test_arrays) 2508 } 2509 validate_tensors(tensors, tensor_origin="test_tensor") 2510 return self 2511 2512 @model_validator(mode="after") 2513 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self: 2514 ipt_refs = {t.id for t in self.inputs} 2515 out_refs = {t.id for t in self.outputs} 2516 for ipt in self.inputs: 2517 for p in ipt.preprocessing: 2518 ref = p.kwargs.get("reference_tensor") 2519 if ref is None: 2520 continue 2521 if ref not in ipt_refs: 2522 raise ValueError( 2523 f"`reference_tensor` '{ref}' not found. Valid input tensor" 2524 + f" references are: {ipt_refs}." 
2525 ) 2526 2527 for out in self.outputs: 2528 for p in out.postprocessing: 2529 ref = p.kwargs.get("reference_tensor") 2530 if ref is None: 2531 continue 2532 2533 if ref not in ipt_refs and ref not in out_refs: 2534 raise ValueError( 2535 f"`reference_tensor` '{ref}' not found. Valid tensor references" 2536 + f" are: {ipt_refs | out_refs}." 2537 ) 2538 2539 return self 2540 2541 # TODO: use validate funcs in validate_test_tensors 2542 # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]: 2543 2544 name: Annotated[ 2545 Annotated[ 2546 str, RestrictCharacters(string.ascii_letters + string.digits + "_- ()") 2547 ], 2548 MinLen(5), 2549 MaxLen(128), 2550 warn(MaxLen(64), "Name longer than 64 characters.", INFO), 2551 ] 2552 """A human-readable name of this model. 2553 It should be no longer than 64 characters 2554 and may only contain letter, number, underscore, minus, parentheses and spaces. 2555 We recommend to chose a name that refers to the model's task and image modality. 2556 """ 2557 2558 outputs: NotEmpty[Sequence[OutputTensorDescr]] 2559 """Describes the output tensors.""" 2560 2561 @field_validator("outputs", mode="after") 2562 @classmethod 2563 def _validate_tensor_ids( 2564 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo 2565 ) -> Sequence[OutputTensorDescr]: 2566 tensor_ids = [ 2567 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", []) 2568 ] 2569 duplicate_tensor_ids: List[str] = [] 2570 seen: Set[str] = set() 2571 for t in tensor_ids: 2572 if t in seen: 2573 duplicate_tensor_ids.append(t) 2574 2575 seen.add(t) 2576 2577 if duplicate_tensor_ids: 2578 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}") 2579 2580 return outputs 2581 2582 @staticmethod 2583 def _get_axes_with_parameterized_size( 2584 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2585 ): 2586 return { 2587 f"{t.id}.{a.id}": (t, a, a.size) 2588 for t in io 2589 for a in t.axes 2590 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize) 2591 } 2592 2593 @staticmethod 2594 def _get_axes_with_independent_size( 2595 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2596 ): 2597 return { 2598 (t.id, a.id): (t, a, a.size) 2599 for t in io 2600 for a in t.axes 2601 if not isinstance(a, BatchAxis) 2602 and isinstance(a.size, (int, ParameterizedSize)) 2603 } 2604 2605 @field_validator("outputs", mode="after") 2606 @classmethod 2607 def _validate_output_axes( 2608 cls, outputs: List[OutputTensorDescr], info: ValidationInfo 2609 ) -> List[OutputTensorDescr]: 2610 input_size_refs = cls._get_axes_with_independent_size( 2611 info.data.get("inputs", []) 2612 ) 2613 output_size_refs = cls._get_axes_with_independent_size(outputs) 2614 2615 for i, out in enumerate(outputs): 2616 valid_independent_refs: Dict[ 2617 Tuple[TensorId, AxisId], 2618 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2619 ] = { 2620 **{ 2621 (out.id, a.id): (out, a, a.size) 2622 for a in out.axes 2623 if not isinstance(a, BatchAxis) 2624 and isinstance(a.size, (int, ParameterizedSize)) 2625 }, 2626 **input_size_refs, 2627 **output_size_refs, 2628 } 2629 for a, ax in enumerate(out.axes): 2630 cls._validate_axis( 2631 "outputs", 2632 i, 2633 out.id, 2634 a, 2635 ax, 2636 valid_independent_refs=valid_independent_refs, 2637 ) 2638 2639 return outputs 2640 2641 packaged_by: List[Author] = Field(default_factory=list) 2642 """The persons that have packaged and uploaded this model. 
2643 Only required if those persons differ from the `authors`.""" 2644 2645 parent: Optional[LinkedModel] = None 2646 """The model from which this model is derived, e.g. by fine-tuning the weights.""" 2647 2648 # todo: add parent self check once we have `id` 2649 # @model_validator(mode="after") 2650 # def validate_parent_is_not_self(self) -> Self: 2651 # if self.parent is not None and self.parent == self.id: 2652 # raise ValueError("The model may not reference itself as parent model") 2653 2654 # return self 2655 2656 run_mode: Annotated[ 2657 Optional[RunMode], 2658 warn(None, "Run mode '{value}' has limited support across consumer softwares."), 2659 ] = None 2660 """Custom run mode for this model: for more complex prediction procedures like test time 2661 data augmentation that currently cannot be expressed in the specification. 2662 No standard run modes are defined yet.""" 2663 2664 timestamp: Datetime = Datetime(datetime.now()) 2665 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format 2666 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat). 2667 (In Python a datetime object is valid, too).""" 2668 2669 training_data: Annotated[ 2670 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02], 2671 Field(union_mode="left_to_right"), 2672 ] = None 2673 """The dataset used to train this model""" 2674 2675 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)] 2676 """The weights for this model. 2677 Weights can be given for different formats, but should otherwise be equivalent. 2678 The available weight formats determine which consumers can use this model.""" 2679 2680 @model_validator(mode="after") 2681 def _add_default_cover(self) -> Self: 2682 if not validation_context_var.get().perform_io_checks or self.covers: 2683 return self 2684 2685 try: 2686 generated_covers = generate_covers( 2687 [(t, load_array(t.test_tensor.download().path)) for t in self.inputs], 2688 [(t, load_array(t.test_tensor.download().path)) for t in self.outputs], 2689 ) 2690 except Exception as e: 2691 issue_warning( 2692 "Failed to generate cover image(s): {e}", 2693 value=self.covers, 2694 msg_context=dict(e=e), 2695 field="covers", 2696 ) 2697 else: 2698 self.covers.extend(generated_covers) 2699 2700 return self 2701 2702 def get_input_test_arrays(self) -> List[NDArray[Any]]: 2703 data = [load_array(ipt.test_tensor.download().path) for ipt in self.inputs] 2704 assert all(isinstance(d, np.ndarray) for d in data) 2705 return data 2706 2707 def get_output_test_arrays(self) -> List[NDArray[Any]]: 2708 data = [load_array(out.test_tensor.download().path) for out in self.outputs] 2709 assert all(isinstance(d, np.ndarray) for d in data) 2710 return data 2711 2712 @staticmethod 2713 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 2714 batch_size = 1 2715 tensor_with_batchsize: Optional[TensorId] = None 2716 for tid in tensor_sizes: 2717 for aid, s in tensor_sizes[tid].items(): 2718 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 2719 continue 2720 2721 if batch_size != 1: 2722 assert tensor_with_batchsize is not None 2723 raise ValueError( 2724 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 2725 ) 2726 2727 batch_size = s 2728 tensor_with_batchsize = tid 2729 2730 return batch_size 2731 2732 def get_output_tensor_sizes( 2733 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 2734 ) -> Dict[TensorId, Dict[AxisId, Union[int, 
_DataDepSize]]]: 2735 """Returns the tensor output sizes for given **input_sizes**. 2736 Only if **input_sizes** has a valid input shape, the tensor output size is exact. 2737 Otherwise it might be larger than the actual (valid) output""" 2738 batch_size = self.get_batch_size(input_sizes) 2739 ns = self.get_ns(input_sizes) 2740 2741 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 2742 return tensor_sizes.outputs 2743 2744 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 2745 """get parameter `n` for each parameterized axis 2746 such that the valid input size is >= the given input size""" 2747 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 2748 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 2749 for tid in input_sizes: 2750 for aid, s in input_sizes[tid].items(): 2751 size_descr = axes[tid][aid].size 2752 if isinstance(size_descr, ParameterizedSize): 2753 ret[(tid, aid)] = size_descr.get_n(s) 2754 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 2755 pass 2756 else: 2757 assert_never(size_descr) 2758 2759 return ret 2760 2761 def get_tensor_sizes( 2762 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 2763 ) -> _TensorSizes: 2764 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 2765 return _TensorSizes( 2766 { 2767 t: { 2768 aa: axis_sizes.inputs[(tt, aa)] 2769 for tt, aa in axis_sizes.inputs 2770 if tt == t 2771 } 2772 for t in {tt for tt, _ in axis_sizes.inputs} 2773 }, 2774 { 2775 t: { 2776 aa: axis_sizes.outputs[(tt, aa)] 2777 for tt, aa in axis_sizes.outputs 2778 if tt == t 2779 } 2780 for t in {tt for tt, _ in axis_sizes.outputs} 2781 }, 2782 ) 2783 2784 def get_axis_sizes( 2785 self, 2786 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], 2787 batch_size: Optional[int] = None, 2788 *, 2789 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None, 2790 ) -> _AxisSizes: 2791 """Determine input and output block shape for scale factors **ns** 2792 of parameterized input sizes. 2793 2794 Args: 2795 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) 2796 that is parameterized as `size = min + n * step`. 2797 batch_size: The desired size of the batch dimension. 2798 If given **batch_size** overwrites any batch size present in 2799 **max_input_shape**. Default 1. 2800 max_input_shape: Limits the derived block shapes. 2801 Each axis for which the input size, parameterized by `n`, is larger 2802 than **max_input_shape** is set to the minimal value `n_min` for which 2803 this is still true. 2804 Use this for small input samples or large values of **ns**. 2805 Or simply whenever you know the full input shape. 2806 2807 Returns: 2808 Resolved axis sizes for model inputs and outputs. 
2809 """ 2810 max_input_shape = max_input_shape or {} 2811 if batch_size is None: 2812 for (_t_id, a_id), s in max_input_shape.items(): 2813 if a_id == BATCH_AXIS_ID: 2814 batch_size = s 2815 break 2816 else: 2817 batch_size = 1 2818 2819 all_axes = { 2820 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs) 2821 } 2822 2823 inputs: Dict[Tuple[TensorId, AxisId], int] = {} 2824 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {} 2825 2826 def get_axis_size(a: Union[InputAxis, OutputAxis]): 2827 if isinstance(a, BatchAxis): 2828 if (t_descr.id, a.id) in ns: 2829 logger.warning( 2830 "Ignoring unexpected size increment factor (n) for batch axis" 2831 + " of tensor '{}'.", 2832 t_descr.id, 2833 ) 2834 return batch_size 2835 elif isinstance(a.size, int): 2836 if (t_descr.id, a.id) in ns: 2837 logger.warning( 2838 "Ignoring unexpected size increment factor (n) for fixed size" 2839 + " axis '{}' of tensor '{}'.", 2840 a.id, 2841 t_descr.id, 2842 ) 2843 return a.size 2844 elif isinstance(a.size, ParameterizedSize): 2845 if (t_descr.id, a.id) not in ns: 2846 raise ValueError( 2847 "Size increment factor (n) missing for parametrized axis" 2848 + f" '{a.id}' of tensor '{t_descr.id}'." 2849 ) 2850 n = ns[(t_descr.id, a.id)] 2851 s_max = max_input_shape.get((t_descr.id, a.id)) 2852 if s_max is not None: 2853 n = min(n, a.size.get_n(s_max)) 2854 2855 return a.size.get_size(n) 2856 2857 elif isinstance(a.size, SizeReference): 2858 if (t_descr.id, a.id) in ns: 2859 logger.warning( 2860 "Ignoring unexpected size increment factor (n) for axis '{}'" 2861 + " of tensor '{}' with size reference.", 2862 a.id, 2863 t_descr.id, 2864 ) 2865 assert not isinstance(a, BatchAxis) 2866 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id] 2867 assert not isinstance(ref_axis, BatchAxis) 2868 ref_key = (a.size.tensor_id, a.size.axis_id) 2869 ref_size = inputs.get(ref_key, outputs.get(ref_key)) 2870 assert ref_size is not None, ref_key 2871 assert not isinstance(ref_size, _DataDepSize), ref_key 2872 return a.size.get_size( 2873 axis=a, 2874 ref_axis=ref_axis, 2875 ref_size=ref_size, 2876 ) 2877 elif isinstance(a.size, DataDependentSize): 2878 if (t_descr.id, a.id) in ns: 2879 logger.warning( 2880 "Ignoring unexpected increment factor (n) for data dependent" 2881 + " size axis '{}' of tensor '{}'.", 2882 a.id, 2883 t_descr.id, 2884 ) 2885 return _DataDepSize(a.size.min, a.size.max) 2886 else: 2887 assert_never(a.size) 2888 2889 # first resolve all , but the `SizeReference` input sizes 2890 for t_descr in self.inputs: 2891 for a in t_descr.axes: 2892 if not isinstance(a.size, SizeReference): 2893 s = get_axis_size(a) 2894 assert not isinstance(s, _DataDepSize) 2895 inputs[t_descr.id, a.id] = s 2896 2897 # resolve all other input axis sizes 2898 for t_descr in self.inputs: 2899 for a in t_descr.axes: 2900 if isinstance(a.size, SizeReference): 2901 s = get_axis_size(a) 2902 assert not isinstance(s, _DataDepSize) 2903 inputs[t_descr.id, a.id] = s 2904 2905 # resolve all output axis sizes 2906 for t_descr in self.outputs: 2907 for a in t_descr.axes: 2908 assert not isinstance(a.size, ParameterizedSize) 2909 s = get_axis_size(a) 2910 outputs[t_descr.id, a.id] = s 2911 2912 return _AxisSizes(inputs=inputs, outputs=outputs) 2913 2914 @model_validator(mode="before") 2915 @classmethod 2916 def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]: 2917 if ( 2918 data.get("type") == "model" 2919 and isinstance(fv := data.get("format_version"), str) 2920 and fv.count(".") == 2 2921 ): 2922 
fv_parts = fv.split(".") 2923 if any(not p.isdigit() for p in fv_parts): 2924 return data 2925 2926 fv_tuple = tuple(map(int, fv_parts)) 2927 2928 assert cls.implemented_format_version_tuple[0:2] == (0, 5) 2929 if fv_tuple[:2] in ((0, 3), (0, 4)): 2930 m04 = _ModelDescr_v0_4.load(data) 2931 if not isinstance(m04, InvalidDescr): 2932 return _model_conv.convert_as_dict(m04) 2933 elif fv_tuple[:2] == (0, 5): 2934 # bump patch version 2935 data["format_version"] = cls.implemented_format_version 2936 2937 return data
Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
Version of the bioimage.io model description specification used. When creating a new model always use the latest micro/patch version described here. The `format_version` is important for any consumer software to understand how to parse the fields.
bioimage.io-wide unique resource identifier assigned by bioimage.io; version unspecific.
∈📦 URL or relative path to a markdown file with additional documentation. The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. The documentation should include a '#[#] Validation' (sub)section with details on how to quantitatively validate the model on unseen data.
Describes the input tensors expected by this model.
A human-readable name of this model. It should be no longer than 64 characters and may only contain letters, numbers, underscores, minus signs, parentheses and spaces. We recommend choosing a name that refers to the model's task and image modality.

The persons that have packaged and uploaded this model. Only required if those persons differ from the `authors`.
The model from which this model is derived, e.g. by fine-tuning the weights.
Custom run mode for this model: for more complex prediction procedures like test time data augmentation that currently cannot be expressed in the specification. No standard run modes are defined yet.
The dataset used to train this model.
The weights for this model. Weights can be given for different formats, but should otherwise be equivalent. The available weight formats determine which consumers can use this model.
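A rough sketch of how these fields are typically read after loading a model RDF; the source URL is a placeholder and `load_description` is assumed to be the generic loader exposed by `bioimageio.spec`:

from bioimageio.spec import load_description
from bioimageio.spec.model.v0_5 import ModelDescr

descr = load_description("https://example.com/some_model/rdf.yaml")  # placeholder source
if not isinstance(descr, ModelDescr):
    raise ValueError(f"expected a v0.5 model description, got {type(descr)}")

print(descr.name, descr.format_version)
print("inputs: ", [str(t.id) for t in descr.inputs])
print("outputs:", [str(t.id) for t in descr.outputs])
print("weights:", sorted(descr.weights.available_formats))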
2712 @staticmethod 2713 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 2714 batch_size = 1 2715 tensor_with_batchsize: Optional[TensorId] = None 2716 for tid in tensor_sizes: 2717 for aid, s in tensor_sizes[tid].items(): 2718 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 2719 continue 2720 2721 if batch_size != 1: 2722 assert tensor_with_batchsize is not None 2723 raise ValueError( 2724 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 2725 ) 2726 2727 batch_size = s 2728 tensor_with_batchsize = tid 2729 2730 return batch_size
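For illustration, a small sketch of calling the static `get_batch_size`; the tensor and axis ids are hypothetical, and only the batch axis entries are inspected (they must agree, otherwise a `ValueError` is raised):

from bioimageio.spec.model.v0_5 import AxisId, BATCH_AXIS_ID, ModelDescr, TensorId

sizes = {
    TensorId("raw"): {BATCH_AXIS_ID: 4, AxisId("y"): 512, AxisId("x"): 512},
    TensorId("mask"): {BATCH_AXIS_ID: 4, AxisId("y"): 512, AxisId("x"): 512},
}
assert ModelDescr.get_batch_size(sizes) == 4
# a batch size mismatch across tensors (e.g. 4 vs 2) would raise a ValueError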
2732 def get_output_tensor_sizes( 2733 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 2734 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: 2735 """Returns the tensor output sizes for given **input_sizes**. 2736 Only if **input_sizes** has a valid input shape, the tensor output size is exact. 2737 Otherwise it might be larger than the actual (valid) output""" 2738 batch_size = self.get_batch_size(input_sizes) 2739 ns = self.get_ns(input_sizes) 2740 2741 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 2742 return tensor_sizes.outputs
Returns the tensor output sizes for the given **input_sizes**. The output sizes are exact only if **input_sizes** is a valid input shape; otherwise they may be larger than the actual (valid) output.
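A hedged sketch of querying output sizes for a concrete input size; the tensor and axis ids are hypothetical and must match the model's declared inputs:

from bioimageio.spec.model.v0_5 import AxisId, BATCH_AXIS_ID, ModelDescr, TensorId


def print_output_sizes(model: ModelDescr) -> None:
    # ids below are hypothetical and must match the model's declared inputs
    input_sizes = {
        TensorId("raw"): {BATCH_AXIS_ID: 1, AxisId("y"): 512, AxisId("x"): 512}
    }
    for tid, axis_sizes in model.get_output_tensor_sizes(input_sizes).items():
        print(tid, dict(axis_sizes))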
2744 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 2745 """get parameter `n` for each parameterized axis 2746 such that the valid input size is >= the given input size""" 2747 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 2748 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 2749 for tid in input_sizes: 2750 for aid, s in input_sizes[tid].items(): 2751 size_descr = axes[tid][aid].size 2752 if isinstance(size_descr, ParameterizedSize): 2753 ret[(tid, aid)] = size_descr.get_n(s) 2754 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 2755 pass 2756 else: 2757 assert_never(size_descr) 2758 2759 return ret
Get parameter `n` for each parameterized axis such that the valid input size is >= the given input size.
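A hedged sketch, assuming the model declares a parameterized size (e.g. `min=64`, `step=16`) for the hypothetical input axis used below:

from bioimageio.spec.model.v0_5 import AxisId, ModelDescr, TensorId


def n_for_requested_size(model: ModelDescr) -> None:
    # assumes input "raw" has a parameterized axis "x" with e.g. min 64 and step 16;
    # requesting 100 then yields n = 3, because 64 + 3 * 16 = 112 is the smallest valid size >= 100
    ns = model.get_ns({TensorId("raw"): {AxisId("x"): 100}})
    print(ns)  # e.g. {(TensorId("raw"), AxisId("x")): 3}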
2761 def get_tensor_sizes( 2762 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 2763 ) -> _TensorSizes: 2764 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 2765 return _TensorSizes( 2766 { 2767 t: { 2768 aa: axis_sizes.inputs[(tt, aa)] 2769 for tt, aa in axis_sizes.inputs 2770 if tt == t 2771 } 2772 for t in {tt for tt, _ in axis_sizes.inputs} 2773 }, 2774 { 2775 t: { 2776 aa: axis_sizes.outputs[(tt, aa)] 2777 for tt, aa in axis_sizes.outputs 2778 if tt == t 2779 } 2780 for t in {tt for tt, _ in axis_sizes.outputs} 2781 }, 2782 )
2784 def get_axis_sizes( 2785 self, 2786 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], 2787 batch_size: Optional[int] = None, 2788 *, 2789 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None, 2790 ) -> _AxisSizes: 2791 """Determine input and output block shape for scale factors **ns** 2792 of parameterized input sizes. 2793 2794 Args: 2795 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) 2796 that is parameterized as `size = min + n * step`. 2797 batch_size: The desired size of the batch dimension. 2798 If given **batch_size** overwrites any batch size present in 2799 **max_input_shape**. Default 1. 2800 max_input_shape: Limits the derived block shapes. 2801 Each axis for which the input size, parameterized by `n`, is larger 2802 than **max_input_shape** is set to the minimal value `n_min` for which 2803 this is still true. 2804 Use this for small input samples or large values of **ns**. 2805 Or simply whenever you know the full input shape. 2806 2807 Returns: 2808 Resolved axis sizes for model inputs and outputs. 2809 """ 2810 max_input_shape = max_input_shape or {} 2811 if batch_size is None: 2812 for (_t_id, a_id), s in max_input_shape.items(): 2813 if a_id == BATCH_AXIS_ID: 2814 batch_size = s 2815 break 2816 else: 2817 batch_size = 1 2818 2819 all_axes = { 2820 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs) 2821 } 2822 2823 inputs: Dict[Tuple[TensorId, AxisId], int] = {} 2824 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {} 2825 2826 def get_axis_size(a: Union[InputAxis, OutputAxis]): 2827 if isinstance(a, BatchAxis): 2828 if (t_descr.id, a.id) in ns: 2829 logger.warning( 2830 "Ignoring unexpected size increment factor (n) for batch axis" 2831 + " of tensor '{}'.", 2832 t_descr.id, 2833 ) 2834 return batch_size 2835 elif isinstance(a.size, int): 2836 if (t_descr.id, a.id) in ns: 2837 logger.warning( 2838 "Ignoring unexpected size increment factor (n) for fixed size" 2839 + " axis '{}' of tensor '{}'.", 2840 a.id, 2841 t_descr.id, 2842 ) 2843 return a.size 2844 elif isinstance(a.size, ParameterizedSize): 2845 if (t_descr.id, a.id) not in ns: 2846 raise ValueError( 2847 "Size increment factor (n) missing for parametrized axis" 2848 + f" '{a.id}' of tensor '{t_descr.id}'." 
2849 ) 2850 n = ns[(t_descr.id, a.id)] 2851 s_max = max_input_shape.get((t_descr.id, a.id)) 2852 if s_max is not None: 2853 n = min(n, a.size.get_n(s_max)) 2854 2855 return a.size.get_size(n) 2856 2857 elif isinstance(a.size, SizeReference): 2858 if (t_descr.id, a.id) in ns: 2859 logger.warning( 2860 "Ignoring unexpected size increment factor (n) for axis '{}'" 2861 + " of tensor '{}' with size reference.", 2862 a.id, 2863 t_descr.id, 2864 ) 2865 assert not isinstance(a, BatchAxis) 2866 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id] 2867 assert not isinstance(ref_axis, BatchAxis) 2868 ref_key = (a.size.tensor_id, a.size.axis_id) 2869 ref_size = inputs.get(ref_key, outputs.get(ref_key)) 2870 assert ref_size is not None, ref_key 2871 assert not isinstance(ref_size, _DataDepSize), ref_key 2872 return a.size.get_size( 2873 axis=a, 2874 ref_axis=ref_axis, 2875 ref_size=ref_size, 2876 ) 2877 elif isinstance(a.size, DataDependentSize): 2878 if (t_descr.id, a.id) in ns: 2879 logger.warning( 2880 "Ignoring unexpected increment factor (n) for data dependent" 2881 + " size axis '{}' of tensor '{}'.", 2882 a.id, 2883 t_descr.id, 2884 ) 2885 return _DataDepSize(a.size.min, a.size.max) 2886 else: 2887 assert_never(a.size) 2888 2889 # first resolve all , but the `SizeReference` input sizes 2890 for t_descr in self.inputs: 2891 for a in t_descr.axes: 2892 if not isinstance(a.size, SizeReference): 2893 s = get_axis_size(a) 2894 assert not isinstance(s, _DataDepSize) 2895 inputs[t_descr.id, a.id] = s 2896 2897 # resolve all other input axis sizes 2898 for t_descr in self.inputs: 2899 for a in t_descr.axes: 2900 if isinstance(a.size, SizeReference): 2901 s = get_axis_size(a) 2902 assert not isinstance(s, _DataDepSize) 2903 inputs[t_descr.id, a.id] = s 2904 2905 # resolve all output axis sizes 2906 for t_descr in self.outputs: 2907 for a in t_descr.axes: 2908 assert not isinstance(a.size, ParameterizedSize) 2909 s = get_axis_size(a) 2910 outputs[t_descr.id, a.id] = s 2911 2912 return _AxisSizes(inputs=inputs, outputs=outputs)
Determine input and output block shape for scale factors **ns** of parameterized input sizes.

Arguments:
- ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) that is parameterized as `size = min + n * step`.
- batch_size: The desired size of the batch dimension. If given, **batch_size** overwrites any batch size present in **max_input_shape**. Default 1.
- max_input_shape: Limits the derived block shapes. Each axis for which the input size, parameterized by `n`, is larger than **max_input_shape** is set to the minimal value `n_min` for which this is still true. Use this for small input samples or large values of **ns**, or simply whenever you know the full input shape.

Returns:
Resolved axis sizes for model inputs and outputs.
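A hedged sketch of resolving a block shape from scale factors; tensor and axis ids are hypothetical and must cover all parameterized input axes of the model:

from bioimageio.spec.model.v0_5 import AxisId, ModelDescr, TensorId


def resolve_block_shape(model: ModelDescr) -> None:
    # hypothetical: input "raw" with parameterized axes "y" and "x"
    ns = {
        (TensorId("raw"), AxisId("y")): 3,
        (TensorId("raw"), AxisId("x")): 3,
    }
    axis_sizes = model.get_axis_sizes(ns, batch_size=2)
    print(axis_sizes.inputs)   # {(tensor_id, axis_id): int, ...}
    print(axis_sizes.outputs)  # ints, or a data-dependent size range for data-dependent output axes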
3162def generate_covers( 3163 inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]], 3164 outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]], 3165) -> List[Path]: 3166 def squeeze( 3167 data: NDArray[Any], axes: Sequence[AnyAxis] 3168 ) -> Tuple[NDArray[Any], List[AnyAxis]]: 3169 """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining""" 3170 if data.ndim != len(axes): 3171 raise ValueError( 3172 f"tensor shape {data.shape} does not match described axes" 3173 + f" {[a.id for a in axes]}" 3174 ) 3175 3176 axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1] 3177 return data.squeeze(), axes 3178 3179 def normalize( 3180 data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7 3181 ) -> NDArray[np.float32]: 3182 data = data.astype("float32") 3183 data -= data.min(axis=axis, keepdims=True) 3184 data /= data.max(axis=axis, keepdims=True) + eps 3185 return data 3186 3187 def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]): 3188 original_shape = data.shape 3189 data, axes = squeeze(data, axes) 3190 3191 # take slice fom any batch or index axis if needed 3192 # and convert the first channel axis and take a slice from any additional channel axes 3193 slices: Tuple[slice, ...] = () 3194 ndim = data.ndim 3195 ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2 3196 has_c_axis = False 3197 for i, a in enumerate(axes): 3198 s = data.shape[i] 3199 assert s > 1 3200 if ( 3201 isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis)) 3202 and ndim > ndim_need 3203 ): 3204 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3205 ndim -= 1 3206 elif isinstance(a, ChannelAxis): 3207 if has_c_axis: 3208 # second channel axis 3209 data = data[slices + (slice(0, 1),)] 3210 ndim -= 1 3211 else: 3212 has_c_axis = True 3213 if s == 2: 3214 # visualize two channels with cyan and magenta 3215 data = np.concatenate( 3216 [ 3217 data[slices + (slice(1, 2),)], 3218 data[slices + (slice(0, 1),)], 3219 ( 3220 data[slices + (slice(0, 1),)] 3221 + data[slices + (slice(1, 2),)] 3222 ) 3223 / 2, # TODO: take maximum instead? 
3224 ], 3225 axis=i, 3226 ) 3227 elif data.shape[i] == 3: 3228 pass # visualize 3 channels as RGB 3229 else: 3230 # visualize first 3 channels as RGB 3231 data = data[slices + (slice(3),)] 3232 3233 assert data.shape[i] == 3 3234 3235 slices += (slice(None),) 3236 3237 data, axes = squeeze(data, axes) 3238 assert len(axes) == ndim 3239 # take slice from z axis if needed 3240 slices = () 3241 if ndim > ndim_need: 3242 for i, a in enumerate(axes): 3243 s = data.shape[i] 3244 if a.id == AxisId("z"): 3245 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3246 data, axes = squeeze(data, axes) 3247 ndim -= 1 3248 break 3249 3250 slices += (slice(None),) 3251 3252 # take slice from any space or time axis 3253 slices = () 3254 3255 for i, a in enumerate(axes): 3256 if ndim <= ndim_need: 3257 break 3258 3259 s = data.shape[i] 3260 assert s > 1 3261 if isinstance( 3262 a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis) 3263 ): 3264 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3265 ndim -= 1 3266 3267 slices += (slice(None),) 3268 3269 del slices 3270 data, axes = squeeze(data, axes) 3271 assert len(axes) == ndim 3272 3273 if (has_c_axis and ndim != 3) or ndim != 2: 3274 raise ValueError( 3275 f"Failed to construct cover image from shape {original_shape}" 3276 ) 3277 3278 if not has_c_axis: 3279 assert ndim == 2 3280 data = np.repeat(data[:, :, None], 3, axis=2) 3281 axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB")))) 3282 ndim += 1 3283 3284 assert ndim == 3 3285 3286 # transpose axis order such that longest axis comes first... 3287 axis_order = list(np.argsort(list(data.shape))) 3288 axis_order.reverse() 3289 # ... and channel axis is last 3290 c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0] 3291 axis_order.append(axis_order.pop(c)) 3292 axes = [axes[ao] for ao in axis_order] 3293 data = data.transpose(axis_order) 3294 3295 # h, w = data.shape[:2] 3296 # if h / w in (1.0 or 2.0): 3297 # pass 3298 # elif h / w < 2: 3299 # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images 3300 3301 norm_along = ( 3302 tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None 3303 ) 3304 # normalize the data and map to 8 bit 3305 data = normalize(data, norm_along) 3306 data = (data * 255).astype("uint8") 3307 3308 return data 3309 3310 def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]): 3311 assert im0.dtype == im1.dtype == np.uint8 3312 assert im0.shape == im1.shape 3313 assert im0.ndim == 3 3314 N, M, C = im0.shape 3315 assert C == 3 3316 out = np.ones((N, M, C), dtype="uint8") 3317 for c in range(C): 3318 outc = np.tril(im0[..., c]) 3319 mask = outc == 0 3320 outc[mask] = np.triu(im1[..., c])[mask] 3321 out[..., c] = outc 3322 3323 return out 3324 3325 ipt_descr, ipt = inputs[0] 3326 out_descr, out = outputs[0] 3327 3328 ipt_img = to_2d_image(ipt, ipt_descr.axes) 3329 out_img = to_2d_image(out, out_descr.axes) 3330 3331 cover_folder = Path(mkdtemp()) 3332 if ipt_img.shape == out_img.shape: 3333 covers = [cover_folder / "cover.png"] 3334 imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img)) 3335 else: 3336 covers = [cover_folder / "input.png", cover_folder / "output.png"] 3337 imwrite(covers[0], ipt_img) 3338 imwrite(covers[1], out_img) 3339 3340 return covers
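The `_add_default_cover` validator of `ModelDescr` drives this function with the model's test tensors; a sketch of doing the same manually (assuming a validated `ModelDescr` with downloadable test tensors; `load_array` is the internal helper imported at the top of this module):

from bioimageio.spec._internal.io_utils import load_array
from bioimageio.spec.model.v0_5 import ModelDescr, generate_covers


def make_covers(model: ModelDescr):
    # pair each tensor description with its test tensor array, as the default-cover validator does
    covers = generate_covers(
        [(t, load_array(t.test_tensor.download().path)) for t in model.inputs],
        [(t, load_array(t.test_tensor.download().path)) for t in model.outputs],
    )
    return covers  # one combined cover.png, or separate input.png/output.png paths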