bioimageio.spec.model.v0_4
1from __future__ import annotations 2 3import collections.abc 4from typing import ( 5 Any, 6 ClassVar, 7 Dict, 8 FrozenSet, 9 List, 10 Literal, 11 Optional, 12 Sequence, 13 Tuple, 14 Union, 15) 16 17import numpy as np 18from annotated_types import Ge, Interval, MaxLen, MinLen, MultipleOf 19from numpy.typing import NDArray 20from pydantic import ( 21 AllowInfNan, 22 Discriminator, 23 Field, 24 SerializationInfo, 25 SerializerFunctionWrapHandler, 26 TypeAdapter, 27 ValidationInfo, 28 WrapSerializer, 29 field_validator, 30 model_validator, 31) 32from typing_extensions import Annotated, LiteralString, Self, assert_never, get_args 33 34from .._internal.common_nodes import ( 35 KwargsNode, 36 Node, 37 NodeWithExplicitlySetFields, 38 StringNode, 39) 40from .._internal.constants import SHA256_HINT 41from .._internal.field_validation import validate_unique_entries 42from .._internal.field_warning import issue_warning, warn 43from .._internal.io import ( 44 BioimageioYamlContent, 45 WithSuffix, 46 download, 47 include_in_package_serializer, 48) 49from .._internal.io import FileDescr as FileDescr 50from .._internal.io_basics import AbsoluteFilePath as AbsoluteFilePath 51from .._internal.io_basics import Sha256 as Sha256 52from .._internal.io_utils import load_array 53from .._internal.packaging_context import packaging_context_var 54from .._internal.types import Datetime as Datetime 55from .._internal.types import Identifier as Identifier 56from .._internal.types import ImportantFileSource, LowerCaseIdentifier 57from .._internal.types import LicenseId as LicenseId 58from .._internal.types import NotEmpty as NotEmpty 59from .._internal.url import HttpUrl as HttpUrl 60from .._internal.validator_annotations import AfterValidator, RestrictCharacters 61from .._internal.version_type import Version as Version 62from .._internal.warning_levels import ALERT, INFO 63from ..dataset.v0_2 import VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS 64from ..dataset.v0_2 import DatasetDescr as DatasetDescr 65from ..dataset.v0_2 import LinkedDataset as LinkedDataset 66from ..generic.v0_2 import AttachmentsDescr as AttachmentsDescr 67from ..generic.v0_2 import Author as Author 68from ..generic.v0_2 import BadgeDescr as BadgeDescr 69from ..generic.v0_2 import CiteEntry as CiteEntry 70from ..generic.v0_2 import Doi as Doi 71from ..generic.v0_2 import GenericModelDescrBase 72from ..generic.v0_2 import LinkedResource as LinkedResource 73from ..generic.v0_2 import Maintainer as Maintainer 74from ..generic.v0_2 import OrcidId as OrcidId 75from ..generic.v0_2 import RelativeFilePath as RelativeFilePath 76from ..generic.v0_2 import ResourceId as ResourceId 77from ..generic.v0_2 import Uploader as Uploader 78from ._v0_4_converter import convert_from_older_format 79 80 81class ModelId(ResourceId): 82 pass 83 84 85AxesStr = Annotated[ 86 str, RestrictCharacters("bitczyx"), AfterValidator(validate_unique_entries) 87] 88AxesInCZYX = Annotated[ 89 str, RestrictCharacters("czyx"), AfterValidator(validate_unique_entries) 90] 91 92PostprocessingName = Literal[ 93 "binarize", 94 "clip", 95 "scale_linear", 96 "sigmoid", 97 "zero_mean_unit_variance", 98 "scale_range", 99 "scale_mean_variance", 100] 101PreprocessingName = Literal[ 102 "binarize", 103 "clip", 104 "scale_linear", 105 "sigmoid", 106 "zero_mean_unit_variance", 107 "scale_range", 108] 109 110 111class TensorName(LowerCaseIdentifier): 112 pass 113 114 115class CallableFromDepencency(StringNode): 116 _pattern = r"^.+\..+$" 117 _submodule_adapter = TypeAdapter(Identifier) 
118 119 module_name: str 120 121 @field_validator("module_name", mode="after") 122 def check_submodules(cls, module_name: str) -> str: 123 for submod in module_name.split("."): 124 _ = cls._submodule_adapter.validate_python(submod) 125 126 return module_name 127 128 callable_name: Identifier 129 130 @classmethod 131 def _get_data(cls, valid_string_data: str): 132 *mods, callname = valid_string_data.split(".") 133 return dict(module_name=".".join(mods), callable_name=callname) 134 135 136class CallableFromFile(StringNode): 137 _pattern = r"^.+:.+$" 138 source_file: Annotated[ 139 Union[RelativeFilePath, HttpUrl], 140 Field(union_mode="left_to_right"), 141 include_in_package_serializer, 142 ] 143 """∈📦 Python module that implements `callable_name`""" 144 callable_name: Identifier 145 """The Python identifier of """ 146 147 @classmethod 148 def _get_data(cls, valid_string_data: str): 149 *file_parts, callname = valid_string_data.split(":") 150 return dict(source_file=":".join(file_parts), callable_name=callname) 151 152 153CustomCallable = Annotated[ 154 Union[CallableFromFile, CallableFromDepencency], Field(union_mode="left_to_right") 155] 156 157 158class Dependencies(StringNode): 159 _pattern = r"^.+:.+$" 160 manager: Annotated[NotEmpty[str], Field(examples=["conda", "maven", "pip"])] 161 """Dependency manager""" 162 163 file: Annotated[ 164 ImportantFileSource, 165 Field(examples=["environment.yaml", "pom.xml", "requirements.txt"]), 166 ] 167 """∈📦 Dependency file""" 168 169 @classmethod 170 def _get_data(cls, valid_string_data: str): 171 manager, *file_parts = valid_string_data.split(":") 172 return dict(manager=manager, file=":".join(file_parts)) 173 174 175WeightsFormat = Literal[ 176 "keras_hdf5", 177 "onnx", 178 "pytorch_state_dict", 179 "tensorflow_js", 180 "tensorflow_saved_model_bundle", 181 "torchscript", 182] 183 184 185class WeightsDescr(Node): 186 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None 187 onnx: Optional[OnnxWeightsDescr] = None 188 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None 189 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None 190 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = ( 191 None 192 ) 193 torchscript: Optional[TorchscriptWeightsDescr] = None 194 195 @model_validator(mode="after") 196 def check_one_entry(self) -> Self: 197 if all( 198 entry is None 199 for entry in [ 200 self.keras_hdf5, 201 self.onnx, 202 self.pytorch_state_dict, 203 self.tensorflow_js, 204 self.tensorflow_saved_model_bundle, 205 self.torchscript, 206 ] 207 ): 208 raise ValueError("Missing weights entry") 209 210 return self 211 212 def __getitem__( 213 self, 214 key: WeightsFormat, 215 ): 216 if key == "keras_hdf5": 217 ret = self.keras_hdf5 218 elif key == "onnx": 219 ret = self.onnx 220 elif key == "pytorch_state_dict": 221 ret = self.pytorch_state_dict 222 elif key == "tensorflow_js": 223 ret = self.tensorflow_js 224 elif key == "tensorflow_saved_model_bundle": 225 ret = self.tensorflow_saved_model_bundle 226 elif key == "torchscript": 227 ret = self.torchscript 228 else: 229 raise KeyError(key) 230 231 if ret is None: 232 raise KeyError(key) 233 234 return ret 235 236 @property 237 def available_formats(self): 238 return { 239 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 240 **({} if self.onnx is None else {"onnx": self.onnx}), 241 **( 242 {} 243 if self.pytorch_state_dict is None 244 else {"pytorch_state_dict": self.pytorch_state_dict} 245 ), 246 **( 247 {} 248 if self.tensorflow_js is None 
249 else {"tensorflow_js": self.tensorflow_js} 250 ), 251 **( 252 {} 253 if self.tensorflow_saved_model_bundle is None 254 else { 255 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 256 } 257 ), 258 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 259 } 260 261 @property 262 def missing_formats(self): 263 return { 264 wf for wf in get_args(WeightsFormat) if wf not in self.available_formats 265 } 266 267 268class WeightsEntryDescrBase(FileDescr): 269 type: ClassVar[WeightsFormat] 270 weights_format_name: ClassVar[str] # human readable 271 272 source: ImportantFileSource 273 """∈📦 The weights file.""" 274 275 attachments: Annotated[ 276 Union[AttachmentsDescr, None], 277 warn(None, "Weights entry depends on additional attachments.", ALERT), 278 ] = None 279 """Attachments that are specific to this weights entry.""" 280 281 authors: Union[List[Author], None] = None 282 """Authors 283 Either the person(s) that have trained this model resulting in the original weights file. 284 (If this is the initial weights entry, i.e. it does not have a `parent`) 285 Or the person(s) who have converted the weights to this weights format. 286 (If this is a child weight, i.e. it has a `parent` field) 287 """ 288 289 dependencies: Annotated[ 290 Optional[Dependencies], 291 warn( 292 None, 293 "Custom dependencies ({value}) specified. Avoid this whenever possible " 294 + "to allow execution in a wider range of software environments.", 295 ), 296 Field( 297 examples=[ 298 "conda:environment.yaml", 299 "maven:./pom.xml", 300 "pip:./requirements.txt", 301 ] 302 ), 303 ] = None 304 """Dependency manager and dependency file, specified as `<dependency manager>:<relative file path>`.""" 305 306 parent: Annotated[ 307 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"]) 308 ] = None 309 """The source weights these weights were converted from. 310 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`, 311 The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights. 312 All weight entries except one (the initial set of weights resulting from training the model), 313 need to have this field.""" 314 315 @model_validator(mode="after") 316 def check_parent_is_not_self(self) -> Self: 317 if self.type == self.parent: 318 raise ValueError("Weights entry can't be it's own parent.") 319 320 return self 321 322 323class KerasHdf5WeightsDescr(WeightsEntryDescrBase): 324 type = "keras_hdf5" 325 weights_format_name: ClassVar[str] = "Keras HDF5" 326 tensorflow_version: Optional[Version] = None 327 """TensorFlow version used to create these weights""" 328 329 @field_validator("tensorflow_version", mode="after") 330 @classmethod 331 def _tfv(cls, value: Any): 332 if value is None: 333 issue_warning( 334 "missing. Please specify the TensorFlow version" 335 + " these weights were created with.", 336 value=value, 337 severity=ALERT, 338 field="tensorflow_version", 339 ) 340 return value 341 342 343class OnnxWeightsDescr(WeightsEntryDescrBase): 344 type = "onnx" 345 weights_format_name: ClassVar[str] = "ONNX" 346 opset_version: Optional[Annotated[int, Ge(7)]] = None 347 """ONNX opset version""" 348 349 @field_validator("opset_version", mode="after") 350 @classmethod 351 def _ov(cls, value: Any): 352 if value is None: 353 issue_warning( 354 "Missing ONNX opset version (aka ONNX opset number). 
" 355 + "Please specify the ONNX opset version these weights were created" 356 + " with.", 357 value=value, 358 severity=ALERT, 359 field="opset_version", 360 ) 361 return value 362 363 364class PytorchStateDictWeightsDescr(WeightsEntryDescrBase): 365 type = "pytorch_state_dict" 366 weights_format_name: ClassVar[str] = "Pytorch State Dict" 367 architecture: CustomCallable = Field( 368 examples=["my_function.py:MyNetworkClass", "my_module.submodule.get_my_model"] 369 ) 370 """callable returning a torch.nn.Module instance. 371 Local implementation: `<relative path to file>:<identifier of implementation within the file>`. 372 Implementation in a dependency: `<dependency-package>.<[dependency-module]>.<identifier>`.""" 373 374 architecture_sha256: Annotated[ 375 Optional[Sha256], 376 Field( 377 description=( 378 "The SHA256 of the architecture source file, if the architecture is not" 379 " defined in a module listed in `dependencies`\n" 380 ) 381 + SHA256_HINT, 382 ), 383 ] = None 384 """The SHA256 of the architecture source file, 385 if the architecture is not defined in a module listed in `dependencies`""" 386 387 @model_validator(mode="after") 388 def check_architecture_sha256(self) -> Self: 389 if isinstance(self.architecture, CallableFromFile): 390 if self.architecture_sha256 is None: 391 raise ValueError( 392 "Missing required `architecture_sha256` for `architecture` with" 393 + " source file." 394 ) 395 elif self.architecture_sha256 is not None: 396 raise ValueError( 397 "Got `architecture_sha256` for architecture that does not have a source" 398 + " file." 399 ) 400 401 return self 402 403 kwargs: Dict[str, Any] = Field(default_factory=dict) 404 """key word arguments for the `architecture` callable""" 405 406 pytorch_version: Optional[Version] = None 407 """Version of the PyTorch library used. 408 If `depencencies` is specified it should include pytorch and the verison has to match. 409 (`dependencies` overrules `pytorch_version`)""" 410 411 @field_validator("pytorch_version", mode="after") 412 @classmethod 413 def _ptv(cls, value: Any): 414 if value is None: 415 issue_warning( 416 "missing. Please specify the PyTorch version these" 417 + " PyTorch state dict weights were created with.", 418 value=value, 419 severity=ALERT, 420 field="pytorch_version", 421 ) 422 return value 423 424 425class TorchscriptWeightsDescr(WeightsEntryDescrBase): 426 type = "torchscript" 427 weights_format_name: ClassVar[str] = "TorchScript" 428 pytorch_version: Optional[Version] = None 429 """Version of the PyTorch library used.""" 430 431 @field_validator("pytorch_version", mode="after") 432 @classmethod 433 def _ptv(cls, value: Any): 434 if value is None: 435 issue_warning( 436 "missing. Please specify the PyTorch version these" 437 + " Torchscript weights were created with.", 438 value=value, 439 severity=ALERT, 440 field="pytorch_version", 441 ) 442 return value 443 444 445class TensorflowJsWeightsDescr(WeightsEntryDescrBase): 446 type = "tensorflow_js" 447 weights_format_name: ClassVar[str] = "Tensorflow.js" 448 tensorflow_version: Optional[Version] = None 449 """Version of the TensorFlow library used.""" 450 451 @field_validator("tensorflow_version", mode="after") 452 @classmethod 453 def _tfv(cls, value: Any): 454 if value is None: 455 issue_warning( 456 "missing. 
Please specify the TensorFlow version" 457 + " these TensorflowJs weights were created with.", 458 value=value, 459 severity=ALERT, 460 field="tensorflow_version", 461 ) 462 return value 463 464 source: ImportantFileSource 465 """∈📦 The multi-file weights. 466 All required files/folders should be a zip archive.""" 467 468 469class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase): 470 type = "tensorflow_saved_model_bundle" 471 weights_format_name: ClassVar[str] = "Tensorflow Saved Model" 472 tensorflow_version: Optional[Version] = None 473 """Version of the TensorFlow library used.""" 474 475 @field_validator("tensorflow_version", mode="after") 476 @classmethod 477 def _tfv(cls, value: Any): 478 if value is None: 479 issue_warning( 480 "missing. Please specify the TensorFlow version" 481 + " these Tensorflow saved model bundle weights were created with.", 482 value=value, 483 severity=ALERT, 484 field="tensorflow_version", 485 ) 486 return value 487 488 489class ParameterizedInputShape(Node): 490 """A sequence of valid shapes given by `shape_k = min + k * step for k in {0, 1, ...}`.""" 491 492 min: NotEmpty[List[int]] 493 """The minimum input shape""" 494 495 step: NotEmpty[List[int]] 496 """The minimum shape change""" 497 498 def __len__(self) -> int: 499 return len(self.min) 500 501 @model_validator(mode="after") 502 def matching_lengths(self) -> Self: 503 if len(self.min) != len(self.step): 504 raise ValueError("`min` and `step` required to have the same length") 505 506 return self 507 508 509class ImplicitOutputShape(Node): 510 """Output tensor shape depending on an input tensor shape. 511 `shape(output_tensor) = shape(input_tensor) * scale + 2 * offset`""" 512 513 reference_tensor: TensorName 514 """Name of the reference tensor.""" 515 516 scale: NotEmpty[List[Optional[float]]] 517 """output_pix/input_pix for each dimension. 518 'null' values indicate new dimensions, whose length is defined by 2*`offset`""" 519 520 offset: NotEmpty[List[Union[int, Annotated[float, MultipleOf(0.5)]]]] 521 """Position of origin wrt to input.""" 522 523 def __len__(self) -> int: 524 return len(self.scale) 525 526 @model_validator(mode="after") 527 def matching_lengths(self) -> Self: 528 if len(self.scale) != len(self.offset): 529 raise ValueError( 530 f"scale {self.scale} has to have same length as offset {self.offset}!" 531 ) 532 # if we have an expanded dimension, make sure that it's offet is not zero 533 for sc, off in zip(self.scale, self.offset): 534 if sc is None and not off: 535 raise ValueError("`offset` must not be zero if `scale` is none/zero") 536 537 return self 538 539 540class TensorDescrBase(Node): 541 name: TensorName 542 """Tensor name. No duplicates are allowed.""" 543 544 description: str = "" 545 546 axes: AxesStr 547 """Axes identifying characters. Same length and order as the axes in `shape`. 548 | axis | description | 549 | --- | --- | 550 | b | batch (groups multiple samples) | 551 | i | instance/index/element | 552 | t | time | 553 | c | channel | 554 | z | spatial dimension z | 555 | y | spatial dimension y | 556 | x | spatial dimension x | 557 """ 558 559 data_range: Optional[ 560 Tuple[Annotated[float, AllowInfNan(True)], Annotated[float, AllowInfNan(True)]] 561 ] = None 562 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor. 
563 If not specified, the full data range that can be expressed in `data_type` is allowed.""" 564 565 566class ProcessingKwargs(KwargsNode): 567 """base class for pre-/postprocessing key word arguments""" 568 569 570class ProcessingDescrBase(NodeWithExplicitlySetFields): 571 """processing base class""" 572 573 # name: Literal[PreprocessingName, PostprocessingName] # todo: make abstract field 574 fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({"name"}) 575 576 577class BinarizeKwargs(ProcessingKwargs): 578 """key word arguments for `BinarizeDescr`""" 579 580 threshold: float 581 """The fixed threshold""" 582 583 584class BinarizeDescr(ProcessingDescrBase): 585 """BinarizeDescr the tensor with a fixed `BinarizeKwargs.threshold`. 586 Values above the threshold will be set to one, values below the threshold to zero. 587 """ 588 589 name: Literal["binarize"] = "binarize" 590 kwargs: BinarizeKwargs 591 592 593class ClipKwargs(ProcessingKwargs): 594 """key word arguments for `ClipDescr`""" 595 596 min: float 597 """minimum value for clipping""" 598 max: float 599 """maximum value for clipping""" 600 601 602class ClipDescr(ProcessingDescrBase): 603 """Clip tensor values to a range. 604 605 Set tensor values below `ClipKwargs.min` to `ClipKwargs.min` 606 and above `ClipKwargs.max` to `ClipKwargs.max`. 607 """ 608 609 name: Literal["clip"] = "clip" 610 611 kwargs: ClipKwargs 612 613 614class ScaleLinearKwargs(ProcessingKwargs): 615 """key word arguments for `ScaleLinearDescr`""" 616 617 axes: Annotated[Optional[AxesInCZYX], Field(examples=["xy"])] = None 618 """The subset of axes to scale jointly. 619 For example xy to scale the two image axes for 2d data jointly.""" 620 621 gain: Union[float, List[float]] = 1.0 622 """multiplicative factor""" 623 624 offset: Union[float, List[float]] = 0.0 625 """additive term""" 626 627 @model_validator(mode="after") 628 def either_gain_or_offset(self) -> Self: 629 if ( 630 self.gain == 1.0 631 or isinstance(self.gain, list) 632 and all(g == 1.0 for g in self.gain) 633 ) and ( 634 self.offset == 0.0 635 or isinstance(self.offset, list) 636 and all(off == 0.0 for off in self.offset) 637 ): 638 raise ValueError( 639 "Redunt linear scaling not allowd. Set `gain` != 1.0 and/or `offset` !=" 640 + " 0.0." 641 ) 642 643 return self 644 645 646class ScaleLinearDescr(ProcessingDescrBase): 647 """Fixed linear scaling.""" 648 649 name: Literal["scale_linear"] = "scale_linear" 650 kwargs: ScaleLinearKwargs 651 652 653class SigmoidDescr(ProcessingDescrBase): 654 """The logistic sigmoid funciton, a.k.a. expit function.""" 655 656 name: Literal["sigmoid"] = "sigmoid" 657 658 @property 659 def kwargs(self) -> ProcessingKwargs: 660 """empty kwargs""" 661 return ProcessingKwargs() 662 663 664class ZeroMeanUnitVarianceKwargs(ProcessingKwargs): 665 """key word arguments for `ZeroMeanUnitVarianceDescr`""" 666 667 mode: Literal["fixed", "per_dataset", "per_sample"] = "fixed" 668 """Mode for computing mean and variance. 669 | mode | description | 670 | ----------- | ------------------------------------ | 671 | fixed | Fixed values for mean and variance | 672 | per_dataset | Compute for the entire dataset | 673 | per_sample | Compute for each sample individually | 674 """ 675 axes: Annotated[AxesInCZYX, Field(examples=["xy"])] 676 """The subset of axes to normalize jointly. 
677 For example `xy` to normalize the two image axes for 2d data jointly.""" 678 679 mean: Annotated[ 680 Union[float, NotEmpty[List[float]], None], Field(examples=[(1.1, 2.2, 3.3)]) 681 ] = None 682 """The mean value(s) to use for `mode: fixed`. 683 For example `[1.1, 2.2, 3.3]` in the case of a 3 channel image with `axes: xy`.""" 684 # todo: check if means match input axes (for mode 'fixed') 685 686 std: Annotated[ 687 Union[float, NotEmpty[List[float]], None], Field(examples=[(0.1, 0.2, 0.3)]) 688 ] = None 689 """The standard deviation values to use for `mode: fixed`. Analogous to mean.""" 690 691 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 692 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`.""" 693 694 @model_validator(mode="after") 695 def mean_and_std_match_mode(self) -> Self: 696 if self.mode == "fixed" and (self.mean is None or self.std is None): 697 raise ValueError("`mean` and `std` are required for `mode: fixed`.") 698 elif self.mode != "fixed" and (self.mean is not None or self.std is not None): 699 raise ValueError(f"`mean` and `std` not allowed for `mode: {self.mode}`") 700 701 return self 702 703 704class ZeroMeanUnitVarianceDescr(ProcessingDescrBase): 705 """Subtract mean and divide by variance.""" 706 707 name: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance" 708 kwargs: ZeroMeanUnitVarianceKwargs 709 710 711class ScaleRangeKwargs(ProcessingKwargs): 712 """key word arguments for `ScaleRangeDescr` 713 714 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default) 715 this processing step normalizes data to the [0, 1] intervall. 716 For other percentiles the normalized values will partially be outside the [0, 1] 717 intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the 718 normalized values to a range. 719 """ 720 721 mode: Literal["per_dataset", "per_sample"] 722 """Mode for computing percentiles. 723 | mode | description | 724 | ----------- | ------------------------------------ | 725 | per_dataset | compute for the entire dataset | 726 | per_sample | compute for each sample individually | 727 """ 728 axes: Annotated[AxesInCZYX, Field(examples=["xy"])] 729 """The subset of axes to normalize jointly. 730 For example xy to normalize the two image axes for 2d data jointly.""" 731 732 min_percentile: Annotated[Union[int, float], Interval(ge=0, lt=100)] = 0.0 733 """The lower percentile used to determine the value to align with zero.""" 734 735 max_percentile: Annotated[Union[int, float], Interval(gt=1, le=100)] = 100.0 736 """The upper percentile used to determine the value to align with one. 737 Has to be bigger than `min_percentile`. 738 The range is 1 to 100 instead of 0 to 100 to avoid mistakenly 739 accepting percentiles specified in the range 0.0 to 1.0.""" 740 741 @model_validator(mode="after") 742 def min_smaller_max(self, info: ValidationInfo) -> Self: 743 if self.min_percentile >= self.max_percentile: 744 raise ValueError( 745 f"min_percentile {self.min_percentile} >= max_percentile" 746 + f" {self.max_percentile}" 747 ) 748 749 return self 750 751 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 752 """Epsilon for numeric stability. 753 `out = (tensor - v_lower) / (v_upper - v_lower + eps)`; 754 with `v_lower,v_upper` values at the respective percentiles.""" 755 756 reference_tensor: Optional[TensorName] = None 757 """Tensor name to compute the percentiles from. Default: The tensor itself. 758 For any tensor in `inputs` only input tensor references are allowed. 
759 For a tensor in `outputs` only input tensor refereences are allowed if `mode: per_dataset`""" 760 761 762class ScaleRangeDescr(ProcessingDescrBase): 763 """Scale with percentiles.""" 764 765 name: Literal["scale_range"] = "scale_range" 766 kwargs: ScaleRangeKwargs 767 768 769class ScaleMeanVarianceKwargs(ProcessingKwargs): 770 """key word arguments for `ScaleMeanVarianceDescr`""" 771 772 mode: Literal["per_dataset", "per_sample"] 773 """Mode for computing mean and variance. 774 | mode | description | 775 | ----------- | ------------------------------------ | 776 | per_dataset | Compute for the entire dataset | 777 | per_sample | Compute for each sample individually | 778 """ 779 780 reference_tensor: TensorName 781 """Name of tensor to match.""" 782 783 axes: Annotated[Optional[AxesInCZYX], Field(examples=["xy"])] = None 784 """The subset of axes to scale jointly. 785 For example xy to normalize the two image axes for 2d data jointly. 786 Default: scale all non-batch axes jointly.""" 787 788 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 789 """Epsilon for numeric stability: 790 "`out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.""" 791 792 793class ScaleMeanVarianceDescr(ProcessingDescrBase): 794 """Scale the tensor s.t. its mean and variance match a reference tensor.""" 795 796 name: Literal["scale_mean_variance"] = "scale_mean_variance" 797 kwargs: ScaleMeanVarianceKwargs 798 799 800PreprocessingDescr = Annotated[ 801 Union[ 802 BinarizeDescr, 803 ClipDescr, 804 ScaleLinearDescr, 805 SigmoidDescr, 806 ZeroMeanUnitVarianceDescr, 807 ScaleRangeDescr, 808 ], 809 Discriminator("name"), 810] 811PostprocessingDescr = Annotated[ 812 Union[ 813 BinarizeDescr, 814 ClipDescr, 815 ScaleLinearDescr, 816 SigmoidDescr, 817 ZeroMeanUnitVarianceDescr, 818 ScaleRangeDescr, 819 ScaleMeanVarianceDescr, 820 ], 821 Discriminator("name"), 822] 823 824 825class InputTensorDescr(TensorDescrBase): 826 data_type: Literal["float32", "uint8", "uint16"] 827 """For now an input tensor is expected to be given as `float32`. 
828 The data flow in bioimage.io models is explained 829 [in this diagram.](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit).""" 830 831 shape: Annotated[ 832 Union[Sequence[int], ParameterizedInputShape], 833 Field( 834 examples=[(1, 512, 512, 1), dict(min=(1, 64, 64, 1), step=(0, 32, 32, 0))] 835 ), 836 ] 837 """Specification of input tensor shape.""" 838 839 preprocessing: List[PreprocessingDescr] = Field(default_factory=list) 840 """Description of how this input should be preprocessed.""" 841 842 @model_validator(mode="after") 843 def zero_batch_step_and_one_batch_size(self) -> Self: 844 bidx = self.axes.find("b") 845 if bidx == -1: 846 return self 847 848 if isinstance(self.shape, ParameterizedInputShape): 849 step = self.shape.step 850 shape = self.shape.min 851 if step[bidx] != 0: 852 raise ValueError( 853 "Input shape step has to be zero in the batch dimension (the batch" 854 + " dimension can always be increased, but `step` should specify how" 855 + " to increase the minimal shape to find the largest single batch" 856 + " shape)" 857 ) 858 else: 859 shape = self.shape 860 861 if shape[bidx] != 1: 862 raise ValueError("Input shape has to be 1 in the batch dimension b.") 863 864 return self 865 866 @model_validator(mode="after") 867 def validate_preprocessing_kwargs(self) -> Self: 868 for p in self.preprocessing: 869 kwargs_axes = p.kwargs.get("axes", "") 870 if not isinstance(kwargs_axes, str): 871 raise ValueError( 872 f"Expected an `axes` string, but got {type(kwargs_axes)}" 873 ) 874 875 if any(a not in self.axes for a in kwargs_axes): 876 raise ValueError("`kwargs.axes` needs to be subset of `axes`") 877 878 return self 879 880 881class OutputTensorDescr(TensorDescrBase): 882 data_type: Literal[ 883 "float32", 884 "float64", 885 "uint8", 886 "int8", 887 "uint16", 888 "int16", 889 "uint32", 890 "int32", 891 "uint64", 892 "int64", 893 "bool", 894 ] 895 """Data type. 896 The data flow in bioimage.io models is explained 897 [in this diagram.](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit).""" 898 899 shape: Union[Sequence[int], ImplicitOutputShape] 900 """Output tensor shape.""" 901 902 halo: Optional[Sequence[int]] = None 903 """The `halo` that should be cropped from the output tensor to avoid boundary effects. 904 The `halo` is to be cropped from both sides, i.e. `shape_after_crop = shape - 2 * halo`. 905 To document a `halo` that is already cropped by the model `shape.offset` has to be used instead.""" 906 907 postprocessing: List[PostprocessingDescr] = Field(default_factory=list) 908 """Description of how this output should be postprocessed.""" 909 910 @model_validator(mode="after") 911 def matching_halo_length(self) -> Self: 912 if self.halo and len(self.halo) != len(self.shape): 913 raise ValueError( 914 f"halo {self.halo} has to have same length as shape {self.shape}!" 
915 ) 916 917 return self 918 919 @model_validator(mode="after") 920 def validate_postprocessing_kwargs(self) -> Self: 921 for p in self.postprocessing: 922 kwargs_axes = p.kwargs.get("axes", "") 923 if not isinstance(kwargs_axes, str): 924 raise ValueError(f"Expected {kwargs_axes} to be a string") 925 926 if any(a not in self.axes for a in kwargs_axes): 927 raise ValueError("`kwargs.axes` needs to be subset of axes") 928 929 return self 930 931 932KnownRunMode = Literal["deepimagej"] 933 934 935class RunMode(Node): 936 name: Annotated[ 937 Union[KnownRunMode, str], warn(KnownRunMode, "Unknown run mode '{value}'.") 938 ] 939 """Run mode name""" 940 941 kwargs: Dict[str, Any] = Field(default_factory=dict) 942 """Run mode specific key word arguments""" 943 944 945class LinkedModel(Node): 946 """Reference to a bioimage.io model.""" 947 948 id: Annotated[ModelId, Field(examples=["affable-shark", "ambitious-sloth"])] 949 """A valid model `id` from the bioimage.io collection.""" 950 951 version_number: Optional[int] = None 952 """version number (n-th published version, not the semantic version) of linked model""" 953 954 955def package_weights( 956 value: Node, # Union[v0_4.WeightsDescr, v0_5.WeightsDescr] 957 handler: SerializerFunctionWrapHandler, 958 info: SerializationInfo, 959): 960 ctxt = packaging_context_var.get() 961 if ctxt is not None and ctxt.weights_priority_order is not None: 962 for wf in ctxt.weights_priority_order: 963 w = getattr(value, wf, None) 964 if w is not None: 965 break 966 else: 967 raise ValueError( 968 "None of the weight formats in `weights_priority_order`" 969 + f" ({ctxt.weights_priority_order}) is present in the given model." 970 ) 971 972 assert isinstance(w, Node), type(w) 973 # construct WeightsDescr with new single weight format entry 974 new_w = w.model_construct(**{k: v for k, v in w if k != "parent"}) 975 value = value.model_construct(None, **{wf: new_w}) 976 977 return handler( 978 value, info # pyright: ignore[reportArgumentType] # taken from pydantic docs 979 ) 980 981 982class ModelDescr(GenericModelDescrBase): 983 """Specification of the fields used in a bioimage.io-compliant RDF that describes AI models with pretrained weights. 984 985 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 986 """ 987 988 format_version: Literal["0.4.10",] = "0.4.10" 989 """Version of the bioimage.io model description specification used. 990 When creating a new model always use the latest micro/patch version described here. 991 The `format_version` is important for any consumer software to understand how to parse the fields. 992 """ 993 994 type: Literal["model"] = "model" 995 """Specialized resource type 'model'""" 996 997 id: Optional[ModelId] = None 998 """bioimage.io-wide unique resource identifier 999 assigned by bioimage.io; version **un**specific.""" 1000 1001 authors: NotEmpty[ # pyright: ignore[reportGeneralTypeIssues] # make mandatory 1002 List[Author] 1003 ] 1004 """The authors are the creators of the model RDF and the primary points of contact.""" 1005 1006 documentation: Annotated[ 1007 ImportantFileSource, 1008 Field( 1009 examples=[ 1010 "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md", 1011 "README.md", 1012 ], 1013 ), 1014 ] 1015 """∈📦 URL or relative path to a markdown file with additional documentation. 1016 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. 
1017 The documentation should include a '[#[#]]# Validation' (sub)section 1018 with details on how to quantitatively validate the model on unseen data.""" 1019 1020 inputs: NotEmpty[List[InputTensorDescr]] 1021 """Describes the input tensors expected by this model.""" 1022 1023 license: Annotated[ 1024 Union[LicenseId, str], 1025 warn(LicenseId, "Unknown license id '{value}'."), 1026 Field(examples=["CC0-1.0", "MIT", "BSD-2-Clause"]), 1027 ] 1028 """A [SPDX license identifier](https://spdx.org/licenses/). 1029 We do notsupport custom license beyond the SPDX license list, if you need that please 1030 [open a GitHub issue](https://github.com/bioimage-io/spec-bioimage-io/issues/new/choose 1031 ) to discuss your intentions with the community.""" 1032 1033 name: Annotated[ 1034 str, 1035 MinLen(1), 1036 warn(MinLen(5), "Name shorter than 5 characters.", INFO), 1037 warn(MaxLen(64), "Name longer than 64 characters.", INFO), 1038 ] 1039 """A human-readable name of this model. 1040 It should be no longer than 64 characters and only contain letter, number, underscore, minus or space characters.""" 1041 1042 outputs: NotEmpty[List[OutputTensorDescr]] 1043 """Describes the output tensors.""" 1044 1045 @field_validator("inputs", "outputs") 1046 @classmethod 1047 def unique_tensor_descr_names( 1048 cls, value: Sequence[Union[InputTensorDescr, OutputTensorDescr]] 1049 ) -> Sequence[Union[InputTensorDescr, OutputTensorDescr]]: 1050 unique_names = {str(v.name) for v in value} 1051 if len(unique_names) != len(value): 1052 raise ValueError("Duplicate tensor descriptor names") 1053 1054 return value 1055 1056 @model_validator(mode="after") 1057 def unique_io_names(self) -> Self: 1058 unique_names = {str(ss.name) for s in (self.inputs, self.outputs) for ss in s} 1059 if len(unique_names) != (len(self.inputs) + len(self.outputs)): 1060 raise ValueError("Duplicate tensor descriptor names across inputs/outputs") 1061 1062 return self 1063 1064 @model_validator(mode="after") 1065 def minimum_shape2valid_output(self) -> Self: 1066 tensors_by_name: Dict[ 1067 TensorName, Union[InputTensorDescr, OutputTensorDescr] 1068 ] = {t.name: t for t in self.inputs + self.outputs} 1069 1070 for out in self.outputs: 1071 if isinstance(out.shape, ImplicitOutputShape): 1072 ndim_ref = len(tensors_by_name[out.shape.reference_tensor].shape) 1073 ndim_out_ref = len( 1074 [scale for scale in out.shape.scale if scale is not None] 1075 ) 1076 if ndim_ref != ndim_out_ref: 1077 expanded_dim_note = ( 1078 " Note that expanded dimensions (`scale`: null) are not" 1079 + f" counted for {out.name}'sdimensionality here." 1080 if None in out.shape.scale 1081 else "" 1082 ) 1083 raise ValueError( 1084 f"Referenced tensor '{out.shape.reference_tensor}' with" 1085 + f" {ndim_ref} dimensions does not match output tensor" 1086 + f" '{out.name}' with" 1087 + f" {ndim_out_ref} dimensions.{expanded_dim_note}" 1088 ) 1089 1090 min_out_shape = self._get_min_shape(out, tensors_by_name) 1091 if out.halo: 1092 halo = out.halo 1093 halo_msg = f" for halo {out.halo}" 1094 else: 1095 halo = [0] * len(min_out_shape) 1096 halo_msg = "" 1097 1098 if any([s - 2 * h < 1 for s, h in zip(min_out_shape, halo)]): 1099 raise ValueError( 1100 f"Minimal shape {min_out_shape} of output {out.name} is too" 1101 + f" small{halo_msg}." 
1102 ) 1103 1104 return self 1105 1106 @classmethod 1107 def _get_min_shape( 1108 cls, 1109 t: Union[InputTensorDescr, OutputTensorDescr], 1110 tensors_by_name: Dict[TensorName, Union[InputTensorDescr, OutputTensorDescr]], 1111 ) -> Sequence[int]: 1112 """output with subtracted halo has to result in meaningful output even for the minimal input 1113 see https://github.com/bioimage-io/spec-bioimage-io/issues/392 1114 """ 1115 if isinstance(t.shape, collections.abc.Sequence): 1116 return t.shape 1117 elif isinstance(t.shape, ParameterizedInputShape): 1118 return t.shape.min 1119 elif isinstance(t.shape, ImplicitOutputShape): 1120 pass 1121 else: 1122 assert_never(t.shape) 1123 1124 ref_shape = cls._get_min_shape( 1125 tensors_by_name[t.shape.reference_tensor], tensors_by_name 1126 ) 1127 1128 if None not in t.shape.scale: 1129 scale: Sequence[float, ...] = t.shape.scale # type: ignore 1130 else: 1131 expanded_dims = [idx for idx, sc in enumerate(t.shape.scale) if sc is None] 1132 new_ref_shape: List[int] = [] 1133 for idx in range(len(t.shape.scale)): 1134 ref_idx = idx - sum(int(exp < idx) for exp in expanded_dims) 1135 new_ref_shape.append(1 if idx in expanded_dims else ref_shape[ref_idx]) 1136 1137 ref_shape = new_ref_shape 1138 assert len(ref_shape) == len(t.shape.scale) 1139 scale = [0.0 if sc is None else sc for sc in t.shape.scale] 1140 1141 offset = t.shape.offset 1142 assert len(offset) == len(scale) 1143 return [int(rs * s + 2 * off) for rs, s, off in zip(ref_shape, scale, offset)] 1144 1145 @model_validator(mode="after") 1146 def validate_tensor_references_in_inputs(self) -> Self: 1147 for t in self.inputs: 1148 for proc in t.preprocessing: 1149 if "reference_tensor" not in proc.kwargs: 1150 continue 1151 1152 ref_tensor = proc.kwargs["reference_tensor"] 1153 if ref_tensor is not None and str(ref_tensor) not in { 1154 str(t.name) for t in self.inputs 1155 }: 1156 raise ValueError(f"'{ref_tensor}' not found in inputs") 1157 1158 if ref_tensor == t.name: 1159 raise ValueError( 1160 f"invalid self reference for preprocessing of tensor {t.name}" 1161 ) 1162 1163 return self 1164 1165 @model_validator(mode="after") 1166 def validate_tensor_references_in_outputs(self) -> Self: 1167 for t in self.outputs: 1168 for proc in t.postprocessing: 1169 if "reference_tensor" not in proc.kwargs: 1170 continue 1171 ref_tensor = proc.kwargs["reference_tensor"] 1172 if ref_tensor is not None and str(ref_tensor) not in { 1173 str(t.name) for t in self.inputs 1174 }: 1175 raise ValueError(f"{ref_tensor} not found in inputs") 1176 1177 return self 1178 1179 packaged_by: List[Author] = Field(default_factory=list) 1180 """The persons that have packaged and uploaded this model. 1181 Only required if those persons differ from the `authors`.""" 1182 1183 parent: Optional[LinkedModel] = None 1184 """The model from which this model is derived, e.g. by fine-tuning the weights.""" 1185 1186 @field_validator("parent", mode="before") 1187 @classmethod 1188 def ignore_url_parent(cls, parent: Any): 1189 if isinstance(parent, dict): 1190 return None 1191 1192 else: 1193 return parent 1194 1195 run_mode: Optional[RunMode] = None 1196 """Custom run mode for this model: for more complex prediction procedures like test time 1197 data augmentation that currently cannot be expressed in the specification. 
1198 No standard run modes are defined yet.""" 1199 1200 sample_inputs: List[ImportantFileSource] = Field(default_factory=list) 1201 """∈📦 URLs/relative paths to sample inputs to illustrate possible inputs for the model, 1202 for example stored as PNG or TIFF images. 1203 The sample files primarily serve to inform a human user about an example use case""" 1204 1205 sample_outputs: List[ImportantFileSource] = Field(default_factory=list) 1206 """∈📦 URLs/relative paths to sample outputs corresponding to the `sample_inputs`.""" 1207 1208 test_inputs: NotEmpty[ 1209 List[Annotated[ImportantFileSource, WithSuffix(".npy", case_sensitive=True)]] 1210 ] 1211 """∈📦 Test input tensors compatible with the `inputs` description for a **single test case**. 1212 This means if your model has more than one input, you should provide one URL/relative path for each input. 1213 Each test input should be a file with an ndarray in 1214 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format). 1215 The extension must be '.npy'.""" 1216 1217 test_outputs: NotEmpty[ 1218 List[Annotated[ImportantFileSource, WithSuffix(".npy", case_sensitive=True)]] 1219 ] 1220 """∈📦 Analog to `test_inputs`.""" 1221 1222 timestamp: Datetime 1223 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format 1224 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).""" 1225 1226 training_data: Union[LinkedDataset, DatasetDescr, None] = None 1227 """The dataset used to train this model""" 1228 1229 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)] 1230 """The weights for this model. 1231 Weights can be given for different formats, but should otherwise be equivalent. 1232 The available weight formats determine which consumers can use this model.""" 1233 1234 @model_validator(mode="before") 1235 @classmethod 1236 def _convert_from_older_format( 1237 cls, data: BioimageioYamlContent, / 1238 ) -> BioimageioYamlContent: 1239 convert_from_older_format(data) 1240 return data 1241 1242 def get_input_test_arrays(self) -> List[NDArray[Any]]: 1243 data = [load_array(download(ipt).path) for ipt in self.test_inputs] 1244 assert all(isinstance(d, np.ndarray) for d in data) 1245 return data 1246 1247 def get_output_test_arrays(self) -> List[NDArray[Any]]: 1248 data = [load_array(download(out).path) for out in self.test_outputs] 1249 assert all(isinstance(d, np.ndarray) for d in data) 1250 return data
class CallableFromDepencency(StringNode):
    _pattern = r"^.+\..+$"
    _submodule_adapter = TypeAdapter(Identifier)

    module_name: str

    @field_validator("module_name", mode="after")
    def check_submodules(cls, module_name: str) -> str:
        for submod in module_name.split("."):
            _ = cls._submodule_adapter.validate_python(submod)

        return module_name

    callable_name: Identifier

    @classmethod
    def _get_data(cls, valid_string_data: str):
        *mods, callname = valid_string_data.split(".")
        return dict(module_name=".".join(mods), callable_name=callname)
class CallableFromFile(StringNode):
    _pattern = r"^.+:.+$"
    source_file: Annotated[
        Union[RelativeFilePath, HttpUrl],
        Field(union_mode="left_to_right"),
        include_in_package_serializer,
    ]
    """∈📦 Python module that implements `callable_name`"""
    callable_name: Identifier
    """The Python identifier of """

    @classmethod
    def _get_data(cls, valid_string_data: str):
        *file_parts, callname = valid_string_data.split(":")
        return dict(source_file=":".join(file_parts), callable_name=callname)
class Dependencies(StringNode):
    _pattern = r"^.+:.+$"
    manager: Annotated[NotEmpty[str], Field(examples=["conda", "maven", "pip"])]
    """Dependency manager"""

    file: Annotated[
        ImportantFileSource,
        Field(examples=["environment.yaml", "pom.xml", "requirements.txt"]),
    ]
    """∈📦 Dependency file"""

    @classmethod
    def _get_data(cls, valid_string_data: str):
        manager, *file_parts = valid_string_data.split(":")
        return dict(manager=manager, file=":".join(file_parts))
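All three string nodes above pack two pieces of information into one string: `Dependencies` uses `<manager>:<file>`, `CallableFromFile` uses `<source file>:<callable>`, and `CallableFromDepencency` uses `<module>.<callable>`. The following self-contained sketch only mirrors the splitting done by the `_get_data` helpers; the example values are taken from the field examples shown above:

dep = "conda:environment.yaml"
manager, _, dep_file = dep.partition(":")
assert (manager, dep_file) == ("conda", "environment.yaml")

from_file = "my_function.py:MyNetworkClass"
*file_parts, callable_name = from_file.split(":")
assert (":".join(file_parts), callable_name) == ("my_function.py", "MyNetworkClass")

from_dependency = "my_module.submodule.get_my_model"
*module_parts, callable_name = from_dependency.split(".")
assert (".".join(module_parts), callable_name) == ("my_module.submodule", "get_my_model")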
class WeightsDescr(Node):
    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
    onnx: Optional[OnnxWeightsDescr] = None
    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
        None
    )
    torchscript: Optional[TorchscriptWeightsDescr] = None

    @model_validator(mode="after")
    def check_one_entry(self) -> Self:
        if all(
            entry is None
            for entry in [
                self.keras_hdf5,
                self.onnx,
                self.pytorch_state_dict,
                self.tensorflow_js,
                self.tensorflow_saved_model_bundle,
                self.torchscript,
            ]
        ):
            raise ValueError("Missing weights entry")

        return self

    def __getitem__(
        self,
        key: WeightsFormat,
    ):
        if key == "keras_hdf5":
            ret = self.keras_hdf5
        elif key == "onnx":
            ret = self.onnx
        elif key == "pytorch_state_dict":
            ret = self.pytorch_state_dict
        elif key == "tensorflow_js":
            ret = self.tensorflow_js
        elif key == "tensorflow_saved_model_bundle":
            ret = self.tensorflow_saved_model_bundle
        elif key == "torchscript":
            ret = self.torchscript
        else:
            raise KeyError(key)

        if ret is None:
            raise KeyError(key)

        return ret

    @property
    def available_formats(self):
        return {
            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
            **({} if self.onnx is None else {"onnx": self.onnx}),
            **(
                {}
                if self.pytorch_state_dict is None
                else {"pytorch_state_dict": self.pytorch_state_dict}
            ),
            **(
                {}
                if self.tensorflow_js is None
                else {"tensorflow_js": self.tensorflow_js}
            ),
            **(
                {}
                if self.tensorflow_saved_model_bundle is None
                else {
                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
                }
            ),
            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
        }

    @property
    def missing_formats(self):
        return {
            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
        }
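A short usage sketch for `WeightsDescr` lookups. It assumes `weights` is an already validated `WeightsDescr` instance (e.g. the `weights` field of a loaded model description); the preference order below is made up for illustration:

priority = ["torchscript", "pytorch_state_dict", "onnx"]  # hypothetical preference order

available = weights.available_formats  # dict of the entries that are actually set
missing = weights.missing_formats      # complementary set of WeightsFormat names

for wf in priority:
    if wf in available:
        entry = weights[wf]  # equivalent to available[wf]; raises KeyError if unset
        print(f"using '{wf}' weights from {entry.source}")
        break
else:
    raise ValueError(f"none of {priority} available; missing: {missing}")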
class WeightsEntryDescrBase(FileDescr):
    type: ClassVar[WeightsFormat]
    weights_format_name: ClassVar[str]  # human readable

    source: ImportantFileSource
    """∈📦 The weights file."""

    attachments: Annotated[
        Union[AttachmentsDescr, None],
        warn(None, "Weights entry depends on additional attachments.", ALERT),
    ] = None
    """Attachments that are specific to this weights entry."""

    authors: Union[List[Author], None] = None
    """Authors
    Either the person(s) that have trained this model resulting in the original weights file.
    (If this is the initial weights entry, i.e. it does not have a `parent`)
    Or the person(s) who have converted the weights to this weights format.
    (If this is a child weight, i.e. it has a `parent` field)
    """

    dependencies: Annotated[
        Optional[Dependencies],
        warn(
            None,
            "Custom dependencies ({value}) specified. Avoid this whenever possible "
            + "to allow execution in a wider range of software environments.",
        ),
        Field(
            examples=[
                "conda:environment.yaml",
                "maven:./pom.xml",
                "pip:./requirements.txt",
            ]
        ),
    ] = None
    """Dependency manager and dependency file, specified as `<dependency manager>:<relative file path>`."""

    parent: Annotated[
        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
    ] = None
    """The source weights these weights were converted from.
    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
    The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
    All weight entries except one (the initial set of weights resulting from training the model),
    need to have this field."""

    @model_validator(mode="after")
    def check_parent_is_not_self(self) -> Self:
        if self.type == self.parent:
            raise ValueError("Weights entry can't be it's own parent.")

        return self
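To make the `parent` relationship concrete: a model trained in PyTorch and later converted to TorchScript would carry two weight entries, and only the converted entry sets `parent`. The mapping below is a hand-written sketch with placeholder file names (other required fields, e.g. the `architecture` of a `pytorch_state_dict` entry, are omitted):

weights = {
    "pytorch_state_dict": {"source": "weights.pt"},  # original training result: no `parent`
    "torchscript": {
        "source": "weights_torchscript.pt",
        "parent": "pytorch_state_dict",  # converted from the entry above
    },
}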
class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
    type = "keras_hdf5"
    weights_format_name: ClassVar[str] = "Keras HDF5"
    tensorflow_version: Optional[Version] = None
    """TensorFlow version used to create these weights"""

    @field_validator("tensorflow_version", mode="after")
    @classmethod
    def _tfv(cls, value: Any):
        if value is None:
            issue_warning(
                "missing. Please specify the TensorFlow version"
                + " these weights were created with.",
                value=value,
                severity=ALERT,
                field="tensorflow_version",
            )
        return value
class OnnxWeightsDescr(WeightsEntryDescrBase):
    type = "onnx"
    weights_format_name: ClassVar[str] = "ONNX"
    opset_version: Optional[Annotated[int, Ge(7)]] = None
    """ONNX opset version"""

    @field_validator("opset_version", mode="after")
    @classmethod
    def _ov(cls, value: Any):
        if value is None:
            issue_warning(
                "Missing ONNX opset version (aka ONNX opset number). "
                + "Please specify the ONNX opset version these weights were created"
                + " with.",
                value=value,
                severity=ALERT,
                field="opset_version",
            )
        return value
class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
    type = "pytorch_state_dict"
    weights_format_name: ClassVar[str] = "Pytorch State Dict"
    architecture: CustomCallable = Field(
        examples=["my_function.py:MyNetworkClass", "my_module.submodule.get_my_model"]
    )
    """callable returning a torch.nn.Module instance.
    Local implementation: `<relative path to file>:<identifier of implementation within the file>`.
    Implementation in a dependency: `<dependency-package>.<[dependency-module]>.<identifier>`."""

    architecture_sha256: Annotated[
        Optional[Sha256],
        Field(
            description=(
                "The SHA256 of the architecture source file, if the architecture is not"
                " defined in a module listed in `dependencies`\n"
            )
            + SHA256_HINT,
        ),
    ] = None
    """The SHA256 of the architecture source file,
    if the architecture is not defined in a module listed in `dependencies`"""

    @model_validator(mode="after")
    def check_architecture_sha256(self) -> Self:
        if isinstance(self.architecture, CallableFromFile):
            if self.architecture_sha256 is None:
                raise ValueError(
                    "Missing required `architecture_sha256` for `architecture` with"
                    + " source file."
                )
        elif self.architecture_sha256 is not None:
            raise ValueError(
                "Got `architecture_sha256` for architecture that does not have a source"
                + " file."
            )

        return self

    kwargs: Dict[str, Any] = Field(default_factory=dict)
    """key word arguments for the `architecture` callable"""

    pytorch_version: Optional[Version] = None
    """Version of the PyTorch library used.
    If `depencencies` is specified it should include pytorch and the verison has to match.
    (`dependencies` overrules `pytorch_version`)"""

    @field_validator("pytorch_version", mode="after")
    @classmethod
    def _ptv(cls, value: Any):
        if value is None:
            issue_warning(
                "missing. Please specify the PyTorch version these"
                + " PyTorch state dict weights were created with.",
                value=value,
                severity=ALERT,
                field="pytorch_version",
            )
        return value
class TorchscriptWeightsDescr(WeightsEntryDescrBase):
    type = "torchscript"
    weights_format_name: ClassVar[str] = "TorchScript"
    pytorch_version: Optional[Version] = None
    """Version of the PyTorch library used."""

    @field_validator("pytorch_version", mode="after")
    @classmethod
    def _ptv(cls, value: Any):
        if value is None:
            issue_warning(
                "missing. Please specify the PyTorch version these"
                + " Torchscript weights were created with.",
                value=value,
                severity=ALERT,
                field="pytorch_version",
            )
        return value
class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
    type = "tensorflow_js"
    weights_format_name: ClassVar[str] = "Tensorflow.js"
    tensorflow_version: Optional[Version] = None
    """Version of the TensorFlow library used."""

    @field_validator("tensorflow_version", mode="after")
    @classmethod
    def _tfv(cls, value: Any):
        if value is None:
            issue_warning(
                "missing. Please specify the TensorFlow version"
                + " these TensorflowJs weights were created with.",
                value=value,
                severity=ALERT,
                field="tensorflow_version",
            )
        return value

    source: ImportantFileSource
    """∈📦 The multi-file weights.
    All required files/folders should be a zip archive."""
class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
    type = "tensorflow_saved_model_bundle"
    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
    tensorflow_version: Optional[Version] = None
    """Version of the TensorFlow library used."""

    @field_validator("tensorflow_version", mode="after")
    @classmethod
    def _tfv(cls, value: Any):
        if value is None:
            issue_warning(
                "missing. Please specify the TensorFlow version"
                + " these Tensorflow saved model bundle weights were created with.",
                value=value,
                severity=ALERT,
                field="tensorflow_version",
            )
        return value
class ParameterizedInputShape(Node):
    """A sequence of valid shapes given by `shape_k = min + k * step for k in {0, 1, ...}`."""

    min: NotEmpty[List[int]]
    """The minimum input shape"""

    step: NotEmpty[List[int]]
    """The minimum shape change"""

    def __len__(self) -> int:
        return len(self.min)

    @model_validator(mode="after")
    def matching_lengths(self) -> Self:
        if len(self.min) != len(self.step):
            raise ValueError("`min` and `step` required to have the same length")

        return self
A sequence of valid shapes given by `shape_k = min + k * step for k in {0, 1, ...}`.
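A small sketch of what the `shape_k = min + k * step` rule yields, assuming the class is instantiated directly (values are illustrative):

from bioimageio.spec.model.v0_4 import ParameterizedInputShape

shape = ParameterizedInputShape(min=[1, 64, 64, 1], step=[0, 32, 32, 0])

# the k-th valid shape is min + k * step, elementwise
for k in range(3):
    print([m + k * s for m, s in zip(shape.min, shape.step)])
# [1, 64, 64, 1]
# [1, 96, 96, 1]
# [1, 128, 128, 1]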
class ImplicitOutputShape(Node):
    """Output tensor shape depending on an input tensor shape.
    `shape(output_tensor) = shape(input_tensor) * scale + 2 * offset`"""

    reference_tensor: TensorName
    """Name of the reference tensor."""

    scale: NotEmpty[List[Optional[float]]]
    """output_pix/input_pix for each dimension.
    'null' values indicate new dimensions, whose length is defined by 2*`offset`"""

    offset: NotEmpty[List[Union[int, Annotated[float, MultipleOf(0.5)]]]]
    """Position of origin with respect to input."""

    def __len__(self) -> int:
        return len(self.scale)

    @model_validator(mode="after")
    def matching_lengths(self) -> Self:
        if len(self.scale) != len(self.offset):
            raise ValueError(
                f"scale {self.scale} has to have same length as offset {self.offset}!"
            )
        # if we have an expanded dimension, make sure that its offset is not zero
        for sc, off in zip(self.scale, self.offset):
            if sc is None and not off:
                raise ValueError("`offset` must not be zero if `scale` is none/zero")

        return self
Output tensor shape depending on an input tensor shape: `shape(output_tensor) = shape(input_tensor) * scale + 2 * offset`.

`output_pix/input_pix` for each dimension. 'null' values indicate new dimensions, whose length is defined by `2*offset`.

Position of origin with respect to input.
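A small sketch of how such an implicit shape maps a reference input shape to an output shape, assuming direct instantiation (all values are illustrative):

from bioimageio.spec.model.v0_4 import ImplicitOutputShape

out_shape = ImplicitOutputShape(
    reference_tensor="raw",  # name of the input tensor the shape refers to
    scale=[1.0, 0.5, 0.5],   # output_pix / input_pix per dimension
    offset=[0, 8, 8],        # added on both sides: + 2 * offset in total
)

input_shape = [1, 256, 256]  # illustrative shape of the referenced tensor
print([int(s * sc + 2 * off) for s, sc, off in zip(input_shape, out_shape.scale, out_shape.offset)])
# -> [1, 144, 144]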
class TensorDescrBase(Node):
    name: TensorName
    """Tensor name. No duplicates are allowed."""

    description: str = ""

    axes: AxesStr
    """Axes identifying characters. Same length and order as the axes in `shape`.
    | axis | description |
    | --- | --- |
    | b | batch (groups multiple samples) |
    | i | instance/index/element |
    | t | time |
    | c | channel |
    | z | spatial dimension z |
    | y | spatial dimension y |
    | x | spatial dimension x |
    """

    data_range: Optional[
        Tuple[Annotated[float, AllowInfNan(True)], Annotated[float, AllowInfNan(True)]]
    ] = None
    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
    If not specified, the full data range that can be expressed in `data_type` is allowed."""
Axes identifying characters. Same length and order as the axes in `shape`.

| axis | description |
| --- | --- |
| b | batch (groups multiple samples) |
| i | instance/index/element |
| t | time |
| c | channel |
| z | spatial dimension z |
| y | spatial dimension y |
| x | spatial dimension x |

Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
If not specified, the full data range that can be expressed in `data_type` is allowed.
class ProcessingKwargs(KwargsNode):
    """base class for pre-/postprocessing key word arguments"""
base class for pre-/postprocessing key word arguments
class ProcessingDescrBase(NodeWithExplicitlySetFields):
    """processing base class"""

    # name: Literal[PreprocessingName, PostprocessingName]  # todo: make abstract field
    fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({"name"})
processing base class
class BinarizeKwargs(ProcessingKwargs):
    """key word arguments for `BinarizeDescr`"""

    threshold: float
    """The fixed threshold"""
key word arguments for BinarizeDescr
class BinarizeDescr(ProcessingDescrBase):
    """Binarize the tensor with a fixed `BinarizeKwargs.threshold`.
    Values above the threshold will be set to one, values below the threshold to zero.
    """

    name: Literal["binarize"] = "binarize"
    kwargs: BinarizeKwargs
Binarize the tensor with a fixed `BinarizeKwargs.threshold`. Values above the threshold will be set to one, values below the threshold to zero.
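In numpy terms the operation amounts to the following sketch (whether a value exactly at the threshold maps to zero or one is not spelled out here):

import numpy as np

threshold = 0.5  # corresponds to BinarizeKwargs(threshold=0.5)
t = np.array([0.1, 0.5, 0.9], dtype="float32")
print((t > threshold).astype("float32"))  # -> [0. 0. 1.]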
class ClipKwargs(ProcessingKwargs):
    """key word arguments for `ClipDescr`"""

    min: float
    """minimum value for clipping"""
    max: float
    """maximum value for clipping"""
key word arguments for ClipDescr
class ClipDescr(ProcessingDescrBase):
    """Clip tensor values to a range.

    Set tensor values below `ClipKwargs.min` to `ClipKwargs.min`
    and above `ClipKwargs.max` to `ClipKwargs.max`.
    """

    name: Literal["clip"] = "clip"

    kwargs: ClipKwargs
Clip tensor values to a range.
Set tensor values below `ClipKwargs.min` to `ClipKwargs.min` and above `ClipKwargs.max` to `ClipKwargs.max`.
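In numpy terms this is a plain clip; a small sketch (min and max values are illustrative):

import numpy as np

# corresponds to ClipKwargs(min=0.0, max=1.0)
t = np.array([-0.5, 0.3, 1.7], dtype="float32")
print(np.clip(t, 0.0, 1.0))  # -> [0.  0.3 1. ]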
class ScaleLinearKwargs(ProcessingKwargs):
    """key word arguments for `ScaleLinearDescr`"""

    axes: Annotated[Optional[AxesInCZYX], Field(examples=["xy"])] = None
    """The subset of axes to scale jointly.
    For example xy to scale the two image axes for 2d data jointly."""

    gain: Union[float, List[float]] = 1.0
    """multiplicative factor"""

    offset: Union[float, List[float]] = 0.0
    """additive term"""

    @model_validator(mode="after")
    def either_gain_or_offset(self) -> Self:
        if (
            self.gain == 1.0
            or isinstance(self.gain, list)
            and all(g == 1.0 for g in self.gain)
        ) and (
            self.offset == 0.0
            or isinstance(self.offset, list)
            and all(off == 0.0 for off in self.offset)
        ):
            raise ValueError(
                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or"
                + " `offset` != 0.0."
            )

        return self
key word arguments for ScaleLinearDescr
The subset of axes to scale jointly. For example xy to scale the two image axes for 2d data jointly.
class ScaleLinearDescr(ProcessingDescrBase):
    """Fixed linear scaling."""

    name: Literal["scale_linear"] = "scale_linear"
    kwargs: ScaleLinearKwargs
Fixed linear scaling.
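As an illustration of the fixed linear scaling, a small sketch assuming the kwargs/descriptor classes are constructed directly; the gain and offset values are arbitrary:

import numpy as np

from bioimageio.spec.model.v0_4 import ScaleLinearDescr, ScaleLinearKwargs

descr = ScaleLinearDescr(kwargs=ScaleLinearKwargs(axes="xy", gain=2.0, offset=1.0))

t = np.array([[0.0, 1.0], [2.0, 3.0]], dtype="float32")
print(t * descr.kwargs.gain + descr.kwargs.offset)  # elementwise gain * t + offset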
class SigmoidDescr(ProcessingDescrBase):
    """The logistic sigmoid function, a.k.a. expit function."""

    name: Literal["sigmoid"] = "sigmoid"

    @property
    def kwargs(self) -> ProcessingKwargs:
        """empty kwargs"""
        return ProcessingKwargs()
The logistic sigmoid function, a.k.a. expit function.
class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
    """key word arguments for `ZeroMeanUnitVarianceDescr`"""

    mode: Literal["fixed", "per_dataset", "per_sample"] = "fixed"
    """Mode for computing mean and variance.
    | mode        | description                          |
    | ----------- | ------------------------------------ |
    | fixed       | Fixed values for mean and variance   |
    | per_dataset | Compute for the entire dataset       |
    | per_sample  | Compute for each sample individually |
    """
    axes: Annotated[AxesInCZYX, Field(examples=["xy"])]
    """The subset of axes to normalize jointly.
    For example `xy` to normalize the two image axes for 2d data jointly."""

    mean: Annotated[
        Union[float, NotEmpty[List[float]], None], Field(examples=[(1.1, 2.2, 3.3)])
    ] = None
    """The mean value(s) to use for `mode: fixed`.
    For example `[1.1, 2.2, 3.3]` in the case of a 3 channel image with `axes: xy`."""
    # todo: check if means match input axes (for mode 'fixed')

    std: Annotated[
        Union[float, NotEmpty[List[float]], None], Field(examples=[(0.1, 0.2, 0.3)])
    ] = None
    """The standard deviation values to use for `mode: fixed`. Analogous to mean."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""

    @model_validator(mode="after")
    def mean_and_std_match_mode(self) -> Self:
        if self.mode == "fixed" and (self.mean is None or self.std is None):
            raise ValueError("`mean` and `std` are required for `mode: fixed`.")
        elif self.mode != "fixed" and (self.mean is not None or self.std is not None):
            raise ValueError(f"`mean` and `std` not allowed for `mode: {self.mode}`")

        return self
key word arguments for ZeroMeanUnitVarianceDescr
Mode for computing mean and variance.

| mode | description |
| --- | --- |
| fixed | Fixed values for mean and variance |
| per_dataset | Compute for the entire dataset |
| per_sample | Compute for each sample individually |

The subset of axes to normalize jointly. For example `xy` to normalize the two image axes for 2d data jointly.

The mean value(s) to use for `mode: fixed`. For example `[1.1, 2.2, 3.3]` in the case of a 3 channel image with `axes: xy`.

The standard deviation values to use for `mode: fixed`. Analogous to mean.

epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`.
class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
    """Subtract mean and divide by variance."""

    name: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
    kwargs: ZeroMeanUnitVarianceKwargs
Subtract mean and divide by variance.
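In numpy terms, the fixed-mode normalization described above amounts to the following sketch (values are illustrative):

import numpy as np

# corresponds to ZeroMeanUnitVarianceKwargs(mode="fixed", axes="xy", mean=10.0, std=2.0)
mean, std, eps = 10.0, 2.0, 1e-6
t = np.array([8.0, 10.0, 14.0], dtype="float32")
print((t - mean) / (std + eps))  # -> approximately [-1.  0.  2.]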
class ScaleRangeKwargs(ProcessingKwargs):
    """key word arguments for `ScaleRangeDescr`

    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
    this processing step normalizes data to the [0, 1] interval.
    For other percentiles the normalized values will partially be outside the [0, 1]
    interval. Use `ScaleRange` followed by `ClipDescr` if you want to limit the
    normalized values to a range.
    """

    mode: Literal["per_dataset", "per_sample"]
    """Mode for computing percentiles.
    | mode        | description                          |
    | ----------- | ------------------------------------ |
    | per_dataset | compute for the entire dataset       |
    | per_sample  | compute for each sample individually |
    """
    axes: Annotated[AxesInCZYX, Field(examples=["xy"])]
    """The subset of axes to normalize jointly.
    For example xy to normalize the two image axes for 2d data jointly."""

    min_percentile: Annotated[Union[int, float], Interval(ge=0, lt=100)] = 0.0
    """The lower percentile used to determine the value to align with zero."""

    max_percentile: Annotated[Union[int, float], Interval(gt=1, le=100)] = 100.0
    """The upper percentile used to determine the value to align with one.
    Has to be bigger than `min_percentile`.
    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
    accepting percentiles specified in the range 0.0 to 1.0."""

    @model_validator(mode="after")
    def min_smaller_max(self, info: ValidationInfo) -> Self:
        if self.min_percentile >= self.max_percentile:
            raise ValueError(
                f"min_percentile {self.min_percentile} >= max_percentile"
                + f" {self.max_percentile}"
            )

        return self

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability.
    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
    with `v_lower,v_upper` values at the respective percentiles."""

    reference_tensor: Optional[TensorName] = None
    """Tensor name to compute the percentiles from. Default: The tensor itself.
    For any tensor in `inputs` only input tensor references are allowed.
    For a tensor in `outputs` only input tensor references are allowed if `mode: per_dataset`"""
key word arguments for ScaleRangeDescr
For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default) this processing step normalizes data to the [0, 1] interval. For other percentiles the normalized values will partially be outside the [0, 1] interval. Use `ScaleRange` followed by `ClipDescr` if you want to limit the normalized values to a range.

Mode for computing percentiles.

| mode | description |
| --- | --- |
| per_dataset | compute for the entire dataset |
| per_sample | compute for each sample individually |

The subset of axes to normalize jointly. For example xy to normalize the two image axes for 2d data jointly.

The lower percentile used to determine the value to align with zero.

The upper percentile used to determine the value to align with one. Has to be bigger than `min_percentile`. The range is 1 to 100 instead of 0 to 100 to avoid mistakenly accepting percentiles specified in the range 0.0 to 1.0.

Epsilon for numeric stability. `out = (tensor - v_lower) / (v_upper - v_lower + eps)`; with `v_lower,v_upper` values at the respective percentiles.

Tensor name to compute the percentiles from. Default: The tensor itself. For any tensor in `inputs` only input tensor references are allowed. For a tensor in `outputs` only input tensor references are allowed if `mode: per_dataset`.
class ScaleRangeDescr(ProcessingDescrBase):
    """Scale with percentiles."""

    name: Literal["scale_range"] = "scale_range"
    kwargs: ScaleRangeKwargs
Scale with percentiles.
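Roughly, per-sample `scale_range` corresponds to the following numpy sketch (the percentile values are illustrative):

import numpy as np

t = np.random.rand(64, 64).astype("float32")
v_lower, v_upper = np.percentile(t, [1.0, 99.0])  # min_percentile=1, max_percentile=99
eps = 1e-6
out = (t - v_lower) / (v_upper - v_lower + eps)   # mostly within [0, 1], but may exceed it
out = np.clip(out, 0.0, 1.0)                      # what a subsequent `clip` step would add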
class ScaleMeanVarianceKwargs(ProcessingKwargs):
    """key word arguments for `ScaleMeanVarianceDescr`"""

    mode: Literal["per_dataset", "per_sample"]
    """Mode for computing mean and variance.
    | mode        | description                          |
    | ----------- | ------------------------------------ |
    | per_dataset | Compute for the entire dataset       |
    | per_sample  | Compute for each sample individually |
    """

    reference_tensor: TensorName
    """Name of tensor to match."""

    axes: Annotated[Optional[AxesInCZYX], Field(examples=["xy"])] = None
    """The subset of axes to scale jointly.
    For example xy to normalize the two image axes for 2d data jointly.
    Default: scale all non-batch axes jointly."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability:
    `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean`."""
key word arguments for ScaleMeanVarianceDescr
Mode for computing mean and variance.

| mode | description |
| --- | --- |
| per_dataset | Compute for the entire dataset |
| per_sample | Compute for each sample individually |
The subset of axes to scale jointly. For example xy to normalize the two image axes for 2d data jointly. Default: scale all non-batch axes jointly.
class ScaleMeanVarianceDescr(ProcessingDescrBase):
    """Scale the tensor s.t. its mean and variance match a reference tensor."""

    name: Literal["scale_mean_variance"] = "scale_mean_variance"
    kwargs: ScaleMeanVarianceKwargs
Scale the tensor s.t. its mean and variance match a reference tensor.
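Written out with numpy, the formula from `ScaleMeanVarianceKwargs` looks like the following sketch (the tensors are illustrative):

import numpy as np

eps = 1e-6
t = np.random.rand(32, 32).astype("float32")
ref = np.random.rand(32, 32).astype("float32") * 4.0 + 2.0  # reference tensor

# out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean
out = (t - t.mean()) / (t.std() + eps) * (ref.std() + eps) + ref.mean()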
class InputTensorDescr(TensorDescrBase):
    data_type: Literal["float32", "uint8", "uint16"]
    """For now an input tensor is expected to be given as `float32`.
    The data flow in bioimage.io models is explained
    [in this diagram.](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit)."""

    shape: Annotated[
        Union[Sequence[int], ParameterizedInputShape],
        Field(
            examples=[(1, 512, 512, 1), dict(min=(1, 64, 64, 1), step=(0, 32, 32, 0))]
        ),
    ]
    """Specification of input tensor shape."""

    preprocessing: List[PreprocessingDescr] = Field(default_factory=list)
    """Description of how this input should be preprocessed."""

    @model_validator(mode="after")
    def zero_batch_step_and_one_batch_size(self) -> Self:
        bidx = self.axes.find("b")
        if bidx == -1:
            return self

        if isinstance(self.shape, ParameterizedInputShape):
            step = self.shape.step
            shape = self.shape.min
            if step[bidx] != 0:
                raise ValueError(
                    "Input shape step has to be zero in the batch dimension (the batch"
                    + " dimension can always be increased, but `step` should specify how"
                    + " to increase the minimal shape to find the largest single batch"
                    + " shape)"
                )
        else:
            shape = self.shape

        if shape[bidx] != 1:
            raise ValueError("Input shape has to be 1 in the batch dimension b.")

        return self

    @model_validator(mode="after")
    def validate_preprocessing_kwargs(self) -> Self:
        for p in self.preprocessing:
            kwargs_axes = p.kwargs.get("axes", "")
            if not isinstance(kwargs_axes, str):
                raise ValueError(
                    f"Expected an `axes` string, but got {type(kwargs_axes)}"
                )

            if any(a not in self.axes for a in kwargs_axes):
                raise ValueError("`kwargs.axes` needs to be subset of `axes`")

        return self
For now an input tensor is expected to be given as `float32`. The data flow in bioimage.io models is explained [in this diagram](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit).
Specification of input tensor shape.
Description of how this input should be preprocessed.
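A minimal sketch of constructing such an input description directly in Python; the tensor name, axes, shape and preprocessing values are illustrative:

from bioimageio.spec.model.v0_4 import (
    InputTensorDescr,
    ParameterizedInputShape,
    ZeroMeanUnitVarianceDescr,
    ZeroMeanUnitVarianceKwargs,
)

ipt = InputTensorDescr(
    name="raw",
    data_type="float32",
    axes="bxyc",
    # batch entry of `min` is 1 and of `step` is 0, as required by the validator above
    shape=ParameterizedInputShape(min=[1, 64, 64, 1], step=[0, 32, 32, 0]),
    preprocessing=[
        ZeroMeanUnitVarianceDescr(
            kwargs=ZeroMeanUnitVarianceKwargs(mode="per_sample", axes="xy")
        )
    ],
)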
class OutputTensorDescr(TensorDescrBase):
    data_type: Literal[
        "float32",
        "float64",
        "uint8",
        "int8",
        "uint16",
        "int16",
        "uint32",
        "int32",
        "uint64",
        "int64",
        "bool",
    ]
    """Data type.
    The data flow in bioimage.io models is explained
    [in this diagram.](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit)."""

    shape: Union[Sequence[int], ImplicitOutputShape]
    """Output tensor shape."""

    halo: Optional[Sequence[int]] = None
    """The `halo` that should be cropped from the output tensor to avoid boundary effects.
    The `halo` is to be cropped from both sides, i.e. `shape_after_crop = shape - 2 * halo`.
    To document a `halo` that is already cropped by the model `shape.offset` has to be used instead."""

    postprocessing: List[PostprocessingDescr] = Field(default_factory=list)
    """Description of how this output should be postprocessed."""

    @model_validator(mode="after")
    def matching_halo_length(self) -> Self:
        if self.halo and len(self.halo) != len(self.shape):
            raise ValueError(
                f"halo {self.halo} has to have same length as shape {self.shape}!"
            )

        return self

    @model_validator(mode="after")
    def validate_postprocessing_kwargs(self) -> Self:
        for p in self.postprocessing:
            kwargs_axes = p.kwargs.get("axes", "")
            if not isinstance(kwargs_axes, str):
                raise ValueError(f"Expected {kwargs_axes} to be a string")

            if any(a not in self.axes for a in kwargs_axes):
                raise ValueError("`kwargs.axes` needs to be subset of axes")

        return self
Data type. The data flow in bioimage.io models is explained [in this diagram](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit).
Description of how this output should be postprocessed.
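Analogously, a sketch of an output description whose shape is implied by an input tensor and that declares a halo; all values are illustrative:

from bioimageio.spec.model.v0_4 import ImplicitOutputShape, OutputTensorDescr

out = OutputTensorDescr(
    name="probabilities",
    data_type="float32",
    axes="bxyc",
    shape=ImplicitOutputShape(
        reference_tensor="raw", scale=[1.0, 1.0, 1.0, 1.0], offset=[0, 0, 0, 0]
    ),
    halo=[0, 16, 16, 0],  # cropped from both sides: shape_after_crop = shape - 2 * halo
)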
class RunMode(Node):
    name: Annotated[
        Union[KnownRunMode, str], warn(KnownRunMode, "Unknown run mode '{value}'.")
    ]
    """Run mode name"""

    kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Run mode specific key word arguments"""
Run mode name
class LinkedModel(Node):
    """Reference to a bioimage.io model."""

    id: Annotated[ModelId, Field(examples=["affable-shark", "ambitious-sloth"])]
    """A valid model `id` from the bioimage.io collection."""

    version_number: Optional[int] = None
    """version number (n-th published version, not the semantic version) of linked model"""
Reference to a bioimage.io model.
A valid model `id` from the bioimage.io collection.
version number (n-th published version, not the semantic version) of linked model
def package_weights(
    value: Node,  # Union[v0_4.WeightsDescr, v0_5.WeightsDescr]
    handler: SerializerFunctionWrapHandler,
    info: SerializationInfo,
):
    ctxt = packaging_context_var.get()
    if ctxt is not None and ctxt.weights_priority_order is not None:
        for wf in ctxt.weights_priority_order:
            w = getattr(value, wf, None)
            if w is not None:
                break
        else:
            raise ValueError(
                "None of the weight formats in `weights_priority_order`"
                + f" ({ctxt.weights_priority_order}) is present in the given model."
            )

        assert isinstance(w, Node), type(w)
        # construct WeightsDescr with new single weight format entry
        new_w = w.model_construct(**{k: v for k, v in w if k != "parent"})
        value = value.model_construct(None, **{wf: new_w})

    return handler(
        value, info  # pyright: ignore[reportArgumentType]  # taken from pydantic docs
    )
class ModelDescr(GenericModelDescrBase):
    """Specification of the fields used in a bioimage.io-compliant RDF that describes AI models with pretrained weights.

    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
    """

    format_version: Literal["0.4.10",] = "0.4.10"
    """Version of the bioimage.io model description specification used.
    When creating a new model always use the latest micro/patch version described here.
    The `format_version` is important for any consumer software to understand how to parse the fields.
    """

    type: Literal["model"] = "model"
    """Specialized resource type 'model'"""

    id: Optional[ModelId] = None
    """bioimage.io-wide unique resource identifier
    assigned by bioimage.io; version **un**specific."""

    authors: NotEmpty[  # pyright: ignore[reportGeneralTypeIssues]  # make mandatory
        List[Author]
    ]
    """The authors are the creators of the model RDF and the primary points of contact."""

    documentation: Annotated[
        ImportantFileSource,
        Field(
            examples=[
                "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md",
                "README.md",
            ],
        ),
    ]
    """∈📦 URL or relative path to a markdown file with additional documentation.
    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
    The documentation should include a '[#[#]]# Validation' (sub)section
    with details on how to quantitatively validate the model on unseen data."""

    inputs: NotEmpty[List[InputTensorDescr]]
    """Describes the input tensors expected by this model."""

    license: Annotated[
        Union[LicenseId, str],
        warn(LicenseId, "Unknown license id '{value}'."),
        Field(examples=["CC0-1.0", "MIT", "BSD-2-Clause"]),
    ]
    """A [SPDX license identifier](https://spdx.org/licenses/).
    We do not support custom licenses beyond the SPDX license list, if you need that please
    [open a GitHub issue](https://github.com/bioimage-io/spec-bioimage-io/issues/new/choose
    ) to discuss your intentions with the community."""

    name: Annotated[
        str,
        MinLen(1),
        warn(MinLen(5), "Name shorter than 5 characters.", INFO),
        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
    ]
    """A human-readable name of this model.
    It should be no longer than 64 characters and only contain letter, number, underscore, minus or space characters."""

    outputs: NotEmpty[List[OutputTensorDescr]]
    """Describes the output tensors."""

    @field_validator("inputs", "outputs")
    @classmethod
    def unique_tensor_descr_names(
        cls, value: Sequence[Union[InputTensorDescr, OutputTensorDescr]]
    ) -> Sequence[Union[InputTensorDescr, OutputTensorDescr]]:
        unique_names = {str(v.name) for v in value}
        if len(unique_names) != len(value):
            raise ValueError("Duplicate tensor descriptor names")

        return value

    @model_validator(mode="after")
    def unique_io_names(self) -> Self:
        unique_names = {str(ss.name) for s in (self.inputs, self.outputs) for ss in s}
        if len(unique_names) != (len(self.inputs) + len(self.outputs)):
            raise ValueError("Duplicate tensor descriptor names across inputs/outputs")

        return self

    @model_validator(mode="after")
    def minimum_shape2valid_output(self) -> Self:
        tensors_by_name: Dict[
            TensorName, Union[InputTensorDescr, OutputTensorDescr]
        ] = {t.name: t for t in self.inputs + self.outputs}

        for out in self.outputs:
            if isinstance(out.shape, ImplicitOutputShape):
                ndim_ref = len(tensors_by_name[out.shape.reference_tensor].shape)
                ndim_out_ref = len(
                    [scale for scale in out.shape.scale if scale is not None]
                )
                if ndim_ref != ndim_out_ref:
                    expanded_dim_note = (
                        " Note that expanded dimensions (`scale`: null) are not"
                        + f" counted for {out.name}'s dimensionality here."
                        if None in out.shape.scale
                        else ""
                    )
                    raise ValueError(
                        f"Referenced tensor '{out.shape.reference_tensor}' with"
                        + f" {ndim_ref} dimensions does not match output tensor"
                        + f" '{out.name}' with"
                        + f" {ndim_out_ref} dimensions.{expanded_dim_note}"
                    )

            min_out_shape = self._get_min_shape(out, tensors_by_name)
            if out.halo:
                halo = out.halo
                halo_msg = f" for halo {out.halo}"
            else:
                halo = [0] * len(min_out_shape)
                halo_msg = ""

            if any([s - 2 * h < 1 for s, h in zip(min_out_shape, halo)]):
                raise ValueError(
                    f"Minimal shape {min_out_shape} of output {out.name} is too"
                    + f" small{halo_msg}."
                )

        return self

    @classmethod
    def _get_min_shape(
        cls,
        t: Union[InputTensorDescr, OutputTensorDescr],
        tensors_by_name: Dict[TensorName, Union[InputTensorDescr, OutputTensorDescr]],
    ) -> Sequence[int]:
        """output with subtracted halo has to result in meaningful output even for the minimal input
        see https://github.com/bioimage-io/spec-bioimage-io/issues/392
        """
        if isinstance(t.shape, collections.abc.Sequence):
            return t.shape
        elif isinstance(t.shape, ParameterizedInputShape):
            return t.shape.min
        elif isinstance(t.shape, ImplicitOutputShape):
            pass
        else:
            assert_never(t.shape)

        ref_shape = cls._get_min_shape(
            tensors_by_name[t.shape.reference_tensor], tensors_by_name
        )

        if None not in t.shape.scale:
            scale: Sequence[float, ...] = t.shape.scale  # type: ignore
        else:
            expanded_dims = [idx for idx, sc in enumerate(t.shape.scale) if sc is None]
            new_ref_shape: List[int] = []
            for idx in range(len(t.shape.scale)):
                ref_idx = idx - sum(int(exp < idx) for exp in expanded_dims)
                new_ref_shape.append(1 if idx in expanded_dims else ref_shape[ref_idx])

            ref_shape = new_ref_shape
            assert len(ref_shape) == len(t.shape.scale)
            scale = [0.0 if sc is None else sc for sc in t.shape.scale]

        offset = t.shape.offset
        assert len(offset) == len(scale)
        return [int(rs * s + 2 * off) for rs, s, off in zip(ref_shape, scale, offset)]

    @model_validator(mode="after")
    def validate_tensor_references_in_inputs(self) -> Self:
        for t in self.inputs:
            for proc in t.preprocessing:
                if "reference_tensor" not in proc.kwargs:
                    continue

                ref_tensor = proc.kwargs["reference_tensor"]
                if ref_tensor is not None and str(ref_tensor) not in {
                    str(t.name) for t in self.inputs
                }:
                    raise ValueError(f"'{ref_tensor}' not found in inputs")

                if ref_tensor == t.name:
                    raise ValueError(
                        f"invalid self reference for preprocessing of tensor {t.name}"
                    )

        return self

    @model_validator(mode="after")
    def validate_tensor_references_in_outputs(self) -> Self:
        for t in self.outputs:
            for proc in t.postprocessing:
                if "reference_tensor" not in proc.kwargs:
                    continue
                ref_tensor = proc.kwargs["reference_tensor"]
                if ref_tensor is not None and str(ref_tensor) not in {
                    str(t.name) for t in self.inputs
                }:
                    raise ValueError(f"{ref_tensor} not found in inputs")

        return self

    packaged_by: List[Author] = Field(default_factory=list)
    """The persons that have packaged and uploaded this model.
    Only required if those persons differ from the `authors`."""

    parent: Optional[LinkedModel] = None
    """The model from which this model is derived, e.g. by fine-tuning the weights."""

    @field_validator("parent", mode="before")
    @classmethod
    def ignore_url_parent(cls, parent: Any):
        if isinstance(parent, dict):
            return None

        else:
            return parent

    run_mode: Optional[RunMode] = None
    """Custom run mode for this model: for more complex prediction procedures like test time
    data augmentation that currently cannot be expressed in the specification.
    No standard run modes are defined yet."""

    sample_inputs: List[ImportantFileSource] = Field(default_factory=list)
    """∈📦 URLs/relative paths to sample inputs to illustrate possible inputs for the model,
    for example stored as PNG or TIFF images.
    The sample files primarily serve to inform a human user about an example use case"""

    sample_outputs: List[ImportantFileSource] = Field(default_factory=list)
    """∈📦 URLs/relative paths to sample outputs corresponding to the `sample_inputs`."""

    test_inputs: NotEmpty[
        List[Annotated[ImportantFileSource, WithSuffix(".npy", case_sensitive=True)]]
    ]
    """∈📦 Test input tensors compatible with the `inputs` description for a **single test case**.
    This means if your model has more than one input, you should provide one URL/relative path for each input.
    Each test input should be a file with an ndarray in
    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
    The extension must be '.npy'."""

    test_outputs: NotEmpty[
        List[Annotated[ImportantFileSource, WithSuffix(".npy", case_sensitive=True)]]
    ]
    """∈📦 Analog to `test_inputs`."""

    timestamp: Datetime
    """Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format
    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat)."""

    training_data: Union[LinkedDataset, DatasetDescr, None] = None
    """The dataset used to train this model"""

    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
    """The weights for this model.
    Weights can be given for different formats, but should otherwise be equivalent.
    The available weight formats determine which consumers can use this model."""

    @model_validator(mode="before")
    @classmethod
    def _convert_from_older_format(
        cls, data: BioimageioYamlContent, /
    ) -> BioimageioYamlContent:
        convert_from_older_format(data)
        return data

    def get_input_test_arrays(self) -> List[NDArray[Any]]:
        data = [load_array(download(ipt).path) for ipt in self.test_inputs]
        assert all(isinstance(d, np.ndarray) for d in data)
        return data

    def get_output_test_arrays(self) -> List[NDArray[Any]]:
        data = [load_array(download(out).path) for out in self.test_outputs]
        assert all(isinstance(d, np.ndarray) for d in data)
        return data
Specification of the fields used in a bioimage.io-compliant RDF that describes AI models with pretrained weights.
These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
Version of the bioimage.io model description specification used. When creating a new model always use the latest micro/patch version described here. The `format_version` is important for any consumer software to understand how to parse the fields.
bioimage.io-wide unique resource identifier assigned by bioimage.io; version unspecific.
∈📦 URL or relative path to a markdown file with additional documentation. The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. The documentation should include a '[#[#]]# Validation' (sub)section with details on how to quantitatively validate the model on unseen data.
Describes the input tensors expected by this model.
A [SPDX license identifier](https://spdx.org/licenses/). We do not support custom licenses beyond the SPDX license list; if you need that, please [open a GitHub issue](https://github.com/bioimage-io/spec-bioimage-io/issues/new/choose) to discuss your intentions with the community.
A human-readable name of this model. It should be no longer than 64 characters and only contain letter, number, underscore, minus or space characters.
The persons that have packaged and uploaded this model. Only required if those persons differ from the `authors`.
The model from which this model is derived, e.g. by fine-tuning the weights.
Custom run mode for this model: for more complex prediction procedures like test time data augmentation that currently cannot be expressed in the specification. No standard run modes are defined yet.
∈📦 URLs/relative paths to sample inputs to illustrate possible inputs for the model, for example stored as PNG or TIFF images. The sample files primarily serve to inform a human user about an example use case
∈📦 URLs/relative paths to sample outputs corresponding to the `sample_inputs`.
∈📦 Test input tensors compatible with the `inputs` description for a **single test case**. This means if your model has more than one input, you should provide one URL/relative path for each input. Each test input should be a file with an ndarray in [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format). The extension must be '.npy'.
∈📦 Analog to `test_inputs`.
The dataset used to train this model
The weights for this model. Weights can be given for different formats, but should otherwise be equivalent. The available weight formats determine which consumers can use this model.
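A rough usage sketch, assuming `ModelDescr` validates like any pydantic v2 model; the dict shown is deliberately incomplete and its values are illustrative (the bioimage.io library also provides higher-level loading helpers not shown here):

from bioimageio.spec.model.v0_4 import ModelDescr

# `rdf` would be the content of a model RDF YAML file parsed into a dict
# (e.g. with ruamel.yaml or pyyaml); shown here only schematically.
rdf = {
    "format_version": "0.4.10",
    "type": "model",
    "name": "my-model",
    # ... authors, cite, documentation, inputs, outputs, license,
    # test_inputs, test_outputs, timestamp, weights, ...
}

try:
    model = ModelDescr.model_validate(rdf)  # standard pydantic v2 validation
except Exception as e:
    print(e)  # an incomplete dict like the one above will list the missing fields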
def wrapped_model_post_init(self: BaseModel, context: Any, /) -> None:
    """We need to both initialize private attributes and call the user-defined model_post_init
    method.
    """
    init_private_attributes(self, context)
    original_model_post_init(self, context)
We need to both initialize private attributes and call the user-defined model_post_init method.