bioimageio.spec.model.v0_4
1from __future__ import annotations 2 3import collections.abc 4from typing import ( 5 TYPE_CHECKING, 6 Any, 7 ClassVar, 8 Dict, 9 List, 10 Literal, 11 Optional, 12 Sequence, 13 Tuple, 14 Type, 15 Union, 16) 17 18import numpy as np 19from annotated_types import Ge, Interval, MaxLen, MinLen, MultipleOf 20from numpy.typing import NDArray 21from pydantic import ( 22 AllowInfNan, 23 Discriminator, 24 Field, 25 RootModel, 26 SerializationInfo, 27 SerializerFunctionWrapHandler, 28 StringConstraints, 29 TypeAdapter, 30 ValidationInfo, 31 WrapSerializer, 32 field_validator, 33 model_validator, 34) 35from typing_extensions import Annotated, Self, assert_never, get_args 36 37from .._internal.common_nodes import ( 38 KwargsNode, 39 Node, 40 NodeWithExplicitlySetFields, 41) 42from .._internal.constants import SHA256_HINT 43from .._internal.field_validation import validate_unique_entries 44from .._internal.field_warning import issue_warning, warn 45from .._internal.io import ( 46 BioimageioYamlContent, 47 WithSuffix, 48 download, 49 include_in_package_serializer, 50) 51from .._internal.io import FileDescr as FileDescr 52from .._internal.io_basics import AbsoluteFilePath as AbsoluteFilePath 53from .._internal.io_basics import Sha256 as Sha256 54from .._internal.io_utils import load_array 55from .._internal.packaging_context import packaging_context_var 56from .._internal.types import Datetime as Datetime 57from .._internal.types import Identifier as Identifier 58from .._internal.types import ImportantFileSource, LowerCaseIdentifier 59from .._internal.types import LicenseId as LicenseId 60from .._internal.types import NotEmpty as NotEmpty 61from .._internal.url import HttpUrl as HttpUrl 62from .._internal.validated_string_with_inner_node import ValidatedStringWithInnerNode 63from .._internal.validator_annotations import AfterValidator, RestrictCharacters 64from .._internal.version_type import Version as Version 65from .._internal.warning_levels import ALERT, INFO 66from ..dataset.v0_2 import VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS 67from ..dataset.v0_2 import DatasetDescr as DatasetDescr 68from ..dataset.v0_2 import LinkedDataset as LinkedDataset 69from ..generic.v0_2 import AttachmentsDescr as AttachmentsDescr 70from ..generic.v0_2 import Author as Author 71from ..generic.v0_2 import BadgeDescr as BadgeDescr 72from ..generic.v0_2 import CiteEntry as CiteEntry 73from ..generic.v0_2 import Doi as Doi 74from ..generic.v0_2 import GenericModelDescrBase 75from ..generic.v0_2 import LinkedResource as LinkedResource 76from ..generic.v0_2 import Maintainer as Maintainer 77from ..generic.v0_2 import OrcidId as OrcidId 78from ..generic.v0_2 import RelativeFilePath as RelativeFilePath 79from ..generic.v0_2 import ResourceId as ResourceId 80from ..generic.v0_2 import Uploader as Uploader 81from ._v0_4_converter import convert_from_older_format 82 83 84class ModelId(ResourceId): 85 pass 86 87 88AxesStr = Annotated[ 89 str, RestrictCharacters("bitczyx"), AfterValidator(validate_unique_entries) 90] 91AxesInCZYX = Annotated[ 92 str, RestrictCharacters("czyx"), AfterValidator(validate_unique_entries) 93] 94 95PostprocessingName = Literal[ 96 "binarize", 97 "clip", 98 "scale_linear", 99 "sigmoid", 100 "zero_mean_unit_variance", 101 "scale_range", 102 "scale_mean_variance", 103] 104PreprocessingName = Literal[ 105 "binarize", 106 "clip", 107 "scale_linear", 108 "sigmoid", 109 "zero_mean_unit_variance", 110 "scale_range", 111] 112 113 114class TensorName(LowerCaseIdentifier): 115 pass 116 117 118class 
CallableFromDepencencyNode(Node): 119 _submodule_adapter: ClassVar[TypeAdapter[Identifier]] = TypeAdapter(Identifier) 120 121 module_name: str 122 """The Python module that implements **callable_name**.""" 123 124 @field_validator("module_name", mode="after") 125 def _check_submodules(cls, module_name: str) -> str: 126 for submod in module_name.split("."): 127 _ = cls._submodule_adapter.validate_python(submod) 128 129 return module_name 130 131 callable_name: Identifier 132 """The callable Python identifier implemented in module **module_name**.""" 133 134 135class CallableFromDepencency(ValidatedStringWithInnerNode[CallableFromDepencencyNode]): 136 _inner_node_class = CallableFromDepencencyNode 137 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 138 Annotated[ 139 str, 140 StringConstraints(strip_whitespace=True, pattern=r"^.+\..+$"), 141 ] 142 ] 143 144 @classmethod 145 def _get_data(cls, valid_string_data: str): 146 *mods, callname = valid_string_data.split(".") 147 return dict(module_name=".".join(mods), callable_name=callname) 148 149 @property 150 def module_name(self): 151 """The Python module that implements **callable_name**.""" 152 return self._inner_node.module_name 153 154 @property 155 def callable_name(self): 156 """The callable Python identifier implemented in module **module_name**.""" 157 return self._inner_node.callable_name 158 159 160class CallableFromFileNode(Node): 161 source_file: Annotated[ 162 Union[RelativeFilePath, HttpUrl], 163 Field(union_mode="left_to_right"), 164 include_in_package_serializer, 165 ] 166 """The Python source file that implements **callable_name**.""" 167 callable_name: Identifier 168 """The callable Python identifier implemented in **source_file**.""" 169 170 171class CallableFromFile(ValidatedStringWithInnerNode[CallableFromFileNode]): 172 _inner_node_class = CallableFromFileNode 173 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 174 Annotated[ 175 str, 176 StringConstraints(strip_whitespace=True, pattern=r"^.+:.+$"), 177 ] 178 ] 179 180 @classmethod 181 def _get_data(cls, valid_string_data: str): 182 *file_parts, callname = valid_string_data.split(":") 183 return dict(source_file=":".join(file_parts), callable_name=callname) 184 185 @property 186 def source_file(self): 187 """The Python source file that implements **callable_name**.""" 188 return self._inner_node.source_file 189 190 @property 191 def callable_name(self): 192 """The callable Python identifier implemented in **source_file**.""" 193 return self._inner_node.callable_name 194 195 196CustomCallable = Annotated[ 197 Union[CallableFromFile, CallableFromDepencency], Field(union_mode="left_to_right") 198] 199 200 201class DependenciesNode(Node): 202 manager: Annotated[NotEmpty[str], Field(examples=["conda", "maven", "pip"])] 203 """Dependency manager""" 204 205 file: ImportantFileSource 206 """Dependency file""" 207 208 209class Dependencies(ValidatedStringWithInnerNode[DependenciesNode]): 210 _inner_node_class = DependenciesNode 211 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 212 Annotated[ 213 str, 214 StringConstraints(strip_whitespace=True, pattern=r"^.+:.+$"), 215 ] 216 ] 217 218 @classmethod 219 def _get_data(cls, valid_string_data: str): 220 manager, *file_parts = valid_string_data.split(":") 221 return dict(manager=manager, file=":".join(file_parts)) 222 223 @property 224 def manager(self): 225 """Dependency manager""" 226 return self._inner_node.manager 227 228 @property 229 def file(self): 230 """Dependency file""" 231 return 
self._inner_node.file 232 233 234WeightsFormat = Literal[ 235 "keras_hdf5", 236 "onnx", 237 "pytorch_state_dict", 238 "tensorflow_js", 239 "tensorflow_saved_model_bundle", 240 "torchscript", 241] 242 243 244class WeightsDescr(Node): 245 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None 246 onnx: Optional[OnnxWeightsDescr] = None 247 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None 248 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None 249 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = ( 250 None 251 ) 252 torchscript: Optional[TorchscriptWeightsDescr] = None 253 254 @model_validator(mode="after") 255 def check_one_entry(self) -> Self: 256 if all( 257 entry is None 258 for entry in [ 259 self.keras_hdf5, 260 self.onnx, 261 self.pytorch_state_dict, 262 self.tensorflow_js, 263 self.tensorflow_saved_model_bundle, 264 self.torchscript, 265 ] 266 ): 267 raise ValueError("Missing weights entry") 268 269 return self 270 271 def __getitem__( 272 self, 273 key: WeightsFormat, 274 ): 275 if key == "keras_hdf5": 276 ret = self.keras_hdf5 277 elif key == "onnx": 278 ret = self.onnx 279 elif key == "pytorch_state_dict": 280 ret = self.pytorch_state_dict 281 elif key == "tensorflow_js": 282 ret = self.tensorflow_js 283 elif key == "tensorflow_saved_model_bundle": 284 ret = self.tensorflow_saved_model_bundle 285 elif key == "torchscript": 286 ret = self.torchscript 287 else: 288 raise KeyError(key) 289 290 if ret is None: 291 raise KeyError(key) 292 293 return ret 294 295 @property 296 def available_formats(self): 297 return { 298 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 299 **({} if self.onnx is None else {"onnx": self.onnx}), 300 **( 301 {} 302 if self.pytorch_state_dict is None 303 else {"pytorch_state_dict": self.pytorch_state_dict} 304 ), 305 **( 306 {} 307 if self.tensorflow_js is None 308 else {"tensorflow_js": self.tensorflow_js} 309 ), 310 **( 311 {} 312 if self.tensorflow_saved_model_bundle is None 313 else { 314 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 315 } 316 ), 317 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 318 } 319 320 @property 321 def missing_formats(self): 322 return { 323 wf for wf in get_args(WeightsFormat) if wf not in self.available_formats 324 } 325 326 327class WeightsEntryDescrBase(FileDescr): 328 type: ClassVar[WeightsFormat] 329 weights_format_name: ClassVar[str] # human readable 330 331 source: ImportantFileSource 332 """The weights file.""" 333 334 attachments: Annotated[ 335 Union[AttachmentsDescr, None], 336 warn(None, "Weights entry depends on additional attachments.", ALERT), 337 ] = None 338 """Attachments that are specific to this weights entry.""" 339 340 authors: Union[List[Author], None] = None 341 """Authors 342 Either the person(s) that have trained this model resulting in the original weights file. 343 (If this is the initial weights entry, i.e. it does not have a `parent`) 344 Or the person(s) who have converted the weights to this weights format. 345 (If this is a child weight, i.e. it has a `parent` field) 346 """ 347 348 dependencies: Annotated[ 349 Optional[Dependencies], 350 warn( 351 None, 352 "Custom dependencies ({value}) specified. 
Avoid this whenever possible " 353 + "to allow execution in a wider range of software environments.", 354 ), 355 Field( 356 examples=[ 357 "conda:environment.yaml", 358 "maven:./pom.xml", 359 "pip:./requirements.txt", 360 ] 361 ), 362 ] = None 363 """Dependency manager and dependency file, specified as `<dependency manager>:<relative file path>`.""" 364 365 parent: Annotated[ 366 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"]) 367 ] = None 368 """The source weights these weights were converted from. 369 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`, 370 The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights. 371 All weight entries except one (the initial set of weights resulting from training the model), 372 need to have this field.""" 373 374 @model_validator(mode="after") 375 def check_parent_is_not_self(self) -> Self: 376 if self.type == self.parent: 377 raise ValueError("Weights entry can't be it's own parent.") 378 379 return self 380 381 382class KerasHdf5WeightsDescr(WeightsEntryDescrBase): 383 type = "keras_hdf5" 384 weights_format_name: ClassVar[str] = "Keras HDF5" 385 tensorflow_version: Optional[Version] = None 386 """TensorFlow version used to create these weights""" 387 388 @field_validator("tensorflow_version", mode="after") 389 @classmethod 390 def _tfv(cls, value: Any): 391 if value is None: 392 issue_warning( 393 "missing. Please specify the TensorFlow version" 394 + " these weights were created with.", 395 value=value, 396 severity=ALERT, 397 field="tensorflow_version", 398 ) 399 return value 400 401 402class OnnxWeightsDescr(WeightsEntryDescrBase): 403 type = "onnx" 404 weights_format_name: ClassVar[str] = "ONNX" 405 opset_version: Optional[Annotated[int, Ge(7)]] = None 406 """ONNX opset version""" 407 408 @field_validator("opset_version", mode="after") 409 @classmethod 410 def _ov(cls, value: Any): 411 if value is None: 412 issue_warning( 413 "Missing ONNX opset version (aka ONNX opset number). " 414 + "Please specify the ONNX opset version these weights were created" 415 + " with.", 416 value=value, 417 severity=ALERT, 418 field="opset_version", 419 ) 420 return value 421 422 423class PytorchStateDictWeightsDescr(WeightsEntryDescrBase): 424 type = "pytorch_state_dict" 425 weights_format_name: ClassVar[str] = "Pytorch State Dict" 426 architecture: CustomCallable = Field( 427 examples=["my_function.py:MyNetworkClass", "my_module.submodule.get_my_model"] 428 ) 429 """callable returning a torch.nn.Module instance. 430 Local implementation: `<relative path to file>:<identifier of implementation within the file>`. 431 Implementation in a dependency: `<dependency-package>.<[dependency-module]>.<identifier>`.""" 432 433 architecture_sha256: Annotated[ 434 Optional[Sha256], 435 Field( 436 description=( 437 "The SHA256 of the architecture source file, if the architecture is not" 438 " defined in a module listed in `dependencies`\n" 439 ) 440 + SHA256_HINT, 441 ), 442 ] = None 443 """The SHA256 of the architecture source file, 444 if the architecture is not defined in a module listed in `dependencies`""" 445 446 @model_validator(mode="after") 447 def check_architecture_sha256(self) -> Self: 448 if isinstance(self.architecture, CallableFromFile): 449 if self.architecture_sha256 is None: 450 raise ValueError( 451 "Missing required `architecture_sha256` for `architecture` with" 452 + " source file." 
453 ) 454 elif self.architecture_sha256 is not None: 455 raise ValueError( 456 "Got `architecture_sha256` for architecture that does not have a source" 457 + " file." 458 ) 459 460 return self 461 462 kwargs: Dict[str, Any] = Field(default_factory=dict) 463 """key word arguments for the `architecture` callable""" 464 465 pytorch_version: Optional[Version] = None 466 """Version of the PyTorch library used. 467 If `depencencies` is specified it should include pytorch and the verison has to match. 468 (`dependencies` overrules `pytorch_version`)""" 469 470 @field_validator("pytorch_version", mode="after") 471 @classmethod 472 def _ptv(cls, value: Any): 473 if value is None: 474 issue_warning( 475 "missing. Please specify the PyTorch version these" 476 + " PyTorch state dict weights were created with.", 477 value=value, 478 severity=ALERT, 479 field="pytorch_version", 480 ) 481 return value 482 483 484class TorchscriptWeightsDescr(WeightsEntryDescrBase): 485 type = "torchscript" 486 weights_format_name: ClassVar[str] = "TorchScript" 487 pytorch_version: Optional[Version] = None 488 """Version of the PyTorch library used.""" 489 490 @field_validator("pytorch_version", mode="after") 491 @classmethod 492 def _ptv(cls, value: Any): 493 if value is None: 494 issue_warning( 495 "missing. Please specify the PyTorch version these" 496 + " Torchscript weights were created with.", 497 value=value, 498 severity=ALERT, 499 field="pytorch_version", 500 ) 501 return value 502 503 504class TensorflowJsWeightsDescr(WeightsEntryDescrBase): 505 type = "tensorflow_js" 506 weights_format_name: ClassVar[str] = "Tensorflow.js" 507 tensorflow_version: Optional[Version] = None 508 """Version of the TensorFlow library used.""" 509 510 @field_validator("tensorflow_version", mode="after") 511 @classmethod 512 def _tfv(cls, value: Any): 513 if value is None: 514 issue_warning( 515 "missing. Please specify the TensorFlow version" 516 + " these TensorflowJs weights were created with.", 517 value=value, 518 severity=ALERT, 519 field="tensorflow_version", 520 ) 521 return value 522 523 source: ImportantFileSource 524 """The multi-file weights. 525 All required files/folders should be a zip archive.""" 526 527 528class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase): 529 type = "tensorflow_saved_model_bundle" 530 weights_format_name: ClassVar[str] = "Tensorflow Saved Model" 531 tensorflow_version: Optional[Version] = None 532 """Version of the TensorFlow library used.""" 533 534 @field_validator("tensorflow_version", mode="after") 535 @classmethod 536 def _tfv(cls, value: Any): 537 if value is None: 538 issue_warning( 539 "missing. Please specify the TensorFlow version" 540 + " these Tensorflow saved model bundle weights were created with.", 541 value=value, 542 severity=ALERT, 543 field="tensorflow_version", 544 ) 545 return value 546 547 548class ParameterizedInputShape(Node): 549 """A sequence of valid shapes given by `shape_k = min + k * step for k in {0, 1, ...}`.""" 550 551 min: NotEmpty[List[int]] 552 """The minimum input shape""" 553 554 step: NotEmpty[List[int]] 555 """The minimum shape change""" 556 557 def __len__(self) -> int: 558 return len(self.min) 559 560 @model_validator(mode="after") 561 def matching_lengths(self) -> Self: 562 if len(self.min) != len(self.step): 563 raise ValueError("`min` and `step` required to have the same length") 564 565 return self 566 567 568class ImplicitOutputShape(Node): 569 """Output tensor shape depending on an input tensor shape. 
570 `shape(output_tensor) = shape(input_tensor) * scale + 2 * offset`""" 571 572 reference_tensor: TensorName 573 """Name of the reference tensor.""" 574 575 scale: NotEmpty[List[Optional[float]]] 576 """output_pix/input_pix for each dimension. 577 'null' values indicate new dimensions, whose length is defined by 2*`offset`""" 578 579 offset: NotEmpty[List[Union[int, Annotated[float, MultipleOf(0.5)]]]] 580 """Position of origin wrt to input.""" 581 582 def __len__(self) -> int: 583 return len(self.scale) 584 585 @model_validator(mode="after") 586 def matching_lengths(self) -> Self: 587 if len(self.scale) != len(self.offset): 588 raise ValueError( 589 f"scale {self.scale} has to have same length as offset {self.offset}!" 590 ) 591 # if we have an expanded dimension, make sure that it's offet is not zero 592 for sc, off in zip(self.scale, self.offset): 593 if sc is None and not off: 594 raise ValueError("`offset` must not be zero if `scale` is none/zero") 595 596 return self 597 598 599class TensorDescrBase(Node): 600 name: TensorName 601 """Tensor name. No duplicates are allowed.""" 602 603 description: str = "" 604 605 axes: AxesStr 606 """Axes identifying characters. Same length and order as the axes in `shape`. 607 | axis | description | 608 | --- | --- | 609 | b | batch (groups multiple samples) | 610 | i | instance/index/element | 611 | t | time | 612 | c | channel | 613 | z | spatial dimension z | 614 | y | spatial dimension y | 615 | x | spatial dimension x | 616 """ 617 618 data_range: Optional[ 619 Tuple[Annotated[float, AllowInfNan(True)], Annotated[float, AllowInfNan(True)]] 620 ] = None 621 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor. 622 If not specified, the full data range that can be expressed in `data_type` is allowed.""" 623 624 625class ProcessingKwargs(KwargsNode): 626 """base class for pre-/postprocessing key word arguments""" 627 628 629class ProcessingDescrBase(NodeWithExplicitlySetFields): 630 """processing base class""" 631 632 633class BinarizeKwargs(ProcessingKwargs): 634 """key word arguments for `BinarizeDescr`""" 635 636 threshold: float 637 """The fixed threshold""" 638 639 640class BinarizeDescr(ProcessingDescrBase): 641 """BinarizeDescr the tensor with a fixed `BinarizeKwargs.threshold`. 642 Values above the threshold will be set to one, values below the threshold to zero. 643 """ 644 645 implemented_name: ClassVar[Literal["binarize"]] = "binarize" 646 if TYPE_CHECKING: 647 name: Literal["binarize"] = "binarize" 648 else: 649 name: Literal["binarize"] 650 651 kwargs: BinarizeKwargs 652 653 654class ClipKwargs(ProcessingKwargs): 655 """key word arguments for `ClipDescr`""" 656 657 min: float 658 """minimum value for clipping""" 659 max: float 660 """maximum value for clipping""" 661 662 663class ClipDescr(ProcessingDescrBase): 664 """Clip tensor values to a range. 665 666 Set tensor values below `ClipKwargs.min` to `ClipKwargs.min` 667 and above `ClipKwargs.max` to `ClipKwargs.max`. 668 """ 669 670 implemented_name: ClassVar[Literal["clip"]] = "clip" 671 if TYPE_CHECKING: 672 name: Literal["clip"] = "clip" 673 else: 674 name: Literal["clip"] 675 676 kwargs: ClipKwargs 677 678 679class ScaleLinearKwargs(ProcessingKwargs): 680 """key word arguments for `ScaleLinearDescr`""" 681 682 axes: Annotated[Optional[AxesInCZYX], Field(examples=["xy"])] = None 683 """The subset of axes to scale jointly. 
684 For example xy to scale the two image axes for 2d data jointly.""" 685 686 gain: Union[float, List[float]] = 1.0 687 """multiplicative factor""" 688 689 offset: Union[float, List[float]] = 0.0 690 """additive term""" 691 692 @model_validator(mode="after") 693 def either_gain_or_offset(self) -> Self: 694 if ( 695 self.gain == 1.0 696 or isinstance(self.gain, list) 697 and all(g == 1.0 for g in self.gain) 698 ) and ( 699 self.offset == 0.0 700 or isinstance(self.offset, list) 701 and all(off == 0.0 for off in self.offset) 702 ): 703 raise ValueError( 704 "Redunt linear scaling not allowd. Set `gain` != 1.0 and/or `offset` !=" 705 + " 0.0." 706 ) 707 708 return self 709 710 711class ScaleLinearDescr(ProcessingDescrBase): 712 """Fixed linear scaling.""" 713 714 implemented_name: ClassVar[Literal["scale_linear"]] = "scale_linear" 715 if TYPE_CHECKING: 716 name: Literal["scale_linear"] = "scale_linear" 717 else: 718 name: Literal["scale_linear"] 719 720 kwargs: ScaleLinearKwargs 721 722 723class SigmoidDescr(ProcessingDescrBase): 724 """The logistic sigmoid funciton, a.k.a. expit function.""" 725 726 implemented_name: ClassVar[Literal["sigmoid"]] = "sigmoid" 727 if TYPE_CHECKING: 728 name: Literal["sigmoid"] = "sigmoid" 729 else: 730 name: Literal["sigmoid"] 731 732 @property 733 def kwargs(self) -> ProcessingKwargs: 734 """empty kwargs""" 735 return ProcessingKwargs() 736 737 738class ZeroMeanUnitVarianceKwargs(ProcessingKwargs): 739 """key word arguments for `ZeroMeanUnitVarianceDescr`""" 740 741 mode: Literal["fixed", "per_dataset", "per_sample"] = "fixed" 742 """Mode for computing mean and variance. 743 | mode | description | 744 | ----------- | ------------------------------------ | 745 | fixed | Fixed values for mean and variance | 746 | per_dataset | Compute for the entire dataset | 747 | per_sample | Compute for each sample individually | 748 """ 749 axes: Annotated[AxesInCZYX, Field(examples=["xy"])] 750 """The subset of axes to normalize jointly. 751 For example `xy` to normalize the two image axes for 2d data jointly.""" 752 753 mean: Annotated[ 754 Union[float, NotEmpty[List[float]], None], Field(examples=[(1.1, 2.2, 3.3)]) 755 ] = None 756 """The mean value(s) to use for `mode: fixed`. 757 For example `[1.1, 2.2, 3.3]` in the case of a 3 channel image with `axes: xy`.""" 758 # todo: check if means match input axes (for mode 'fixed') 759 760 std: Annotated[ 761 Union[float, NotEmpty[List[float]], None], Field(examples=[(0.1, 0.2, 0.3)]) 762 ] = None 763 """The standard deviation values to use for `mode: fixed`. 
Analogous to mean.""" 764 765 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 766 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`.""" 767 768 @model_validator(mode="after") 769 def mean_and_std_match_mode(self) -> Self: 770 if self.mode == "fixed" and (self.mean is None or self.std is None): 771 raise ValueError("`mean` and `std` are required for `mode: fixed`.") 772 elif self.mode != "fixed" and (self.mean is not None or self.std is not None): 773 raise ValueError(f"`mean` and `std` not allowed for `mode: {self.mode}`") 774 775 return self 776 777 778class ZeroMeanUnitVarianceDescr(ProcessingDescrBase): 779 """Subtract mean and divide by variance.""" 780 781 implemented_name: ClassVar[Literal["zero_mean_unit_variance"]] = ( 782 "zero_mean_unit_variance" 783 ) 784 if TYPE_CHECKING: 785 name: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance" 786 else: 787 name: Literal["zero_mean_unit_variance"] 788 789 kwargs: ZeroMeanUnitVarianceKwargs 790 791 792class ScaleRangeKwargs(ProcessingKwargs): 793 """key word arguments for `ScaleRangeDescr` 794 795 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default) 796 this processing step normalizes data to the [0, 1] intervall. 797 For other percentiles the normalized values will partially be outside the [0, 1] 798 intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the 799 normalized values to a range. 800 """ 801 802 mode: Literal["per_dataset", "per_sample"] 803 """Mode for computing percentiles. 804 | mode | description | 805 | ----------- | ------------------------------------ | 806 | per_dataset | compute for the entire dataset | 807 | per_sample | compute for each sample individually | 808 """ 809 axes: Annotated[AxesInCZYX, Field(examples=["xy"])] 810 """The subset of axes to normalize jointly. 811 For example xy to normalize the two image axes for 2d data jointly.""" 812 813 min_percentile: Annotated[Union[int, float], Interval(ge=0, lt=100)] = 0.0 814 """The lower percentile used to determine the value to align with zero.""" 815 816 max_percentile: Annotated[Union[int, float], Interval(gt=1, le=100)] = 100.0 817 """The upper percentile used to determine the value to align with one. 818 Has to be bigger than `min_percentile`. 819 The range is 1 to 100 instead of 0 to 100 to avoid mistakenly 820 accepting percentiles specified in the range 0.0 to 1.0.""" 821 822 @model_validator(mode="after") 823 def min_smaller_max(self, info: ValidationInfo) -> Self: 824 if self.min_percentile >= self.max_percentile: 825 raise ValueError( 826 f"min_percentile {self.min_percentile} >= max_percentile" 827 + f" {self.max_percentile}" 828 ) 829 830 return self 831 832 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 833 """Epsilon for numeric stability. 834 `out = (tensor - v_lower) / (v_upper - v_lower + eps)`; 835 with `v_lower,v_upper` values at the respective percentiles.""" 836 837 reference_tensor: Optional[TensorName] = None 838 """Tensor name to compute the percentiles from. Default: The tensor itself. 839 For any tensor in `inputs` only input tensor references are allowed. 
840 For a tensor in `outputs` only input tensor refereences are allowed if `mode: per_dataset`""" 841 842 843class ScaleRangeDescr(ProcessingDescrBase): 844 """Scale with percentiles.""" 845 846 implemented_name: ClassVar[Literal["scale_range"]] = "scale_range" 847 if TYPE_CHECKING: 848 name: Literal["scale_range"] = "scale_range" 849 else: 850 name: Literal["scale_range"] 851 852 kwargs: ScaleRangeKwargs 853 854 855class ScaleMeanVarianceKwargs(ProcessingKwargs): 856 """key word arguments for `ScaleMeanVarianceDescr`""" 857 858 mode: Literal["per_dataset", "per_sample"] 859 """Mode for computing mean and variance. 860 | mode | description | 861 | ----------- | ------------------------------------ | 862 | per_dataset | Compute for the entire dataset | 863 | per_sample | Compute for each sample individually | 864 """ 865 866 reference_tensor: TensorName 867 """Name of tensor to match.""" 868 869 axes: Annotated[Optional[AxesInCZYX], Field(examples=["xy"])] = None 870 """The subset of axes to scale jointly. 871 For example xy to normalize the two image axes for 2d data jointly. 872 Default: scale all non-batch axes jointly.""" 873 874 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 875 """Epsilon for numeric stability: 876 "`out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.""" 877 878 879class ScaleMeanVarianceDescr(ProcessingDescrBase): 880 """Scale the tensor s.t. its mean and variance match a reference tensor.""" 881 882 implemented_name: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance" 883 if TYPE_CHECKING: 884 name: Literal["scale_mean_variance"] = "scale_mean_variance" 885 else: 886 name: Literal["scale_mean_variance"] 887 888 kwargs: ScaleMeanVarianceKwargs 889 890 891PreprocessingDescr = Annotated[ 892 Union[ 893 BinarizeDescr, 894 ClipDescr, 895 ScaleLinearDescr, 896 SigmoidDescr, 897 ZeroMeanUnitVarianceDescr, 898 ScaleRangeDescr, 899 ], 900 Discriminator("name"), 901] 902PostprocessingDescr = Annotated[ 903 Union[ 904 BinarizeDescr, 905 ClipDescr, 906 ScaleLinearDescr, 907 SigmoidDescr, 908 ZeroMeanUnitVarianceDescr, 909 ScaleRangeDescr, 910 ScaleMeanVarianceDescr, 911 ], 912 Discriminator("name"), 913] 914 915 916class InputTensorDescr(TensorDescrBase): 917 data_type: Literal["float32", "uint8", "uint16"] 918 """For now an input tensor is expected to be given as `float32`. 
919 The data flow in bioimage.io models is explained 920 [in this diagram.](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit).""" 921 922 shape: Annotated[ 923 Union[Sequence[int], ParameterizedInputShape], 924 Field( 925 examples=[(1, 512, 512, 1), dict(min=(1, 64, 64, 1), step=(0, 32, 32, 0))] 926 ), 927 ] 928 """Specification of input tensor shape.""" 929 930 preprocessing: List[PreprocessingDescr] = Field(default_factory=list) 931 """Description of how this input should be preprocessed.""" 932 933 @model_validator(mode="after") 934 def zero_batch_step_and_one_batch_size(self) -> Self: 935 bidx = self.axes.find("b") 936 if bidx == -1: 937 return self 938 939 if isinstance(self.shape, ParameterizedInputShape): 940 step = self.shape.step 941 shape = self.shape.min 942 if step[bidx] != 0: 943 raise ValueError( 944 "Input shape step has to be zero in the batch dimension (the batch" 945 + " dimension can always be increased, but `step` should specify how" 946 + " to increase the minimal shape to find the largest single batch" 947 + " shape)" 948 ) 949 else: 950 shape = self.shape 951 952 if shape[bidx] != 1: 953 raise ValueError("Input shape has to be 1 in the batch dimension b.") 954 955 return self 956 957 @model_validator(mode="after") 958 def validate_preprocessing_kwargs(self) -> Self: 959 for p in self.preprocessing: 960 kwargs_axes = p.kwargs.get("axes") 961 if isinstance(kwargs_axes, str) and any( 962 a not in self.axes for a in kwargs_axes 963 ): 964 raise ValueError("`kwargs.axes` needs to be subset of `axes`") 965 966 return self 967 968 969class OutputTensorDescr(TensorDescrBase): 970 data_type: Literal[ 971 "float32", 972 "float64", 973 "uint8", 974 "int8", 975 "uint16", 976 "int16", 977 "uint32", 978 "int32", 979 "uint64", 980 "int64", 981 "bool", 982 ] 983 """Data type. 984 The data flow in bioimage.io models is explained 985 [in this diagram.](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit).""" 986 987 shape: Union[Sequence[int], ImplicitOutputShape] 988 """Output tensor shape.""" 989 990 halo: Optional[Sequence[int]] = None 991 """The `halo` that should be cropped from the output tensor to avoid boundary effects. 992 The `halo` is to be cropped from both sides, i.e. `shape_after_crop = shape - 2 * halo`. 993 To document a `halo` that is already cropped by the model `shape.offset` has to be used instead.""" 994 995 postprocessing: List[PostprocessingDescr] = Field(default_factory=list) 996 """Description of how this output should be postprocessed.""" 997 998 @model_validator(mode="after") 999 def matching_halo_length(self) -> Self: 1000 if self.halo and len(self.halo) != len(self.shape): 1001 raise ValueError( 1002 f"halo {self.halo} has to have same length as shape {self.shape}!" 
1003 ) 1004 1005 return self 1006 1007 @model_validator(mode="after") 1008 def validate_postprocessing_kwargs(self) -> Self: 1009 for p in self.postprocessing: 1010 kwargs_axes = p.kwargs.get("axes", "") 1011 if not isinstance(kwargs_axes, str): 1012 raise ValueError(f"Expected {kwargs_axes} to be a string") 1013 1014 if any(a not in self.axes for a in kwargs_axes): 1015 raise ValueError("`kwargs.axes` needs to be subset of axes") 1016 1017 return self 1018 1019 1020KnownRunMode = Literal["deepimagej"] 1021 1022 1023class RunMode(Node): 1024 name: Annotated[ 1025 Union[KnownRunMode, str], warn(KnownRunMode, "Unknown run mode '{value}'.") 1026 ] 1027 """Run mode name""" 1028 1029 kwargs: Dict[str, Any] = Field(default_factory=dict) 1030 """Run mode specific key word arguments""" 1031 1032 1033class LinkedModel(Node): 1034 """Reference to a bioimage.io model.""" 1035 1036 id: Annotated[ModelId, Field(examples=["affable-shark", "ambitious-sloth"])] 1037 """A valid model `id` from the bioimage.io collection.""" 1038 1039 version_number: Optional[int] = None 1040 """version number (n-th published version, not the semantic version) of linked model""" 1041 1042 1043def package_weights( 1044 value: Node, # Union[v0_4.WeightsDescr, v0_5.WeightsDescr] 1045 handler: SerializerFunctionWrapHandler, 1046 info: SerializationInfo, 1047): 1048 ctxt = packaging_context_var.get() 1049 if ctxt is not None and ctxt.weights_priority_order is not None: 1050 for wf in ctxt.weights_priority_order: 1051 w = getattr(value, wf, None) 1052 if w is not None: 1053 break 1054 else: 1055 raise ValueError( 1056 "None of the weight formats in `weights_priority_order`" 1057 + f" ({ctxt.weights_priority_order}) is present in the given model." 1058 ) 1059 1060 assert isinstance(w, Node), type(w) 1061 # construct WeightsDescr with new single weight format entry 1062 new_w = w.model_construct(**{k: v for k, v in w if k != "parent"}) 1063 value = value.model_construct(None, **{wf: new_w}) 1064 1065 return handler( 1066 value, info # pyright: ignore[reportArgumentType] # taken from pydantic docs 1067 ) 1068 1069 1070class ModelDescr(GenericModelDescrBase): 1071 """Specification of the fields used in a bioimage.io-compliant RDF that describes AI models with pretrained weights. 1072 1073 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 1074 """ 1075 1076 implemented_format_version: ClassVar[Literal["0.4.10"]] = "0.4.10" 1077 if TYPE_CHECKING: 1078 format_version: Literal["0.4.10"] = "0.4.10" 1079 else: 1080 format_version: Literal["0.4.10"] 1081 """Version of the bioimage.io model description specification used. 1082 When creating a new model always use the latest micro/patch version described here. 1083 The `format_version` is important for any consumer software to understand how to parse the fields. 
1084 """ 1085 1086 implemented_type: ClassVar[Literal["model"]] = "model" 1087 if TYPE_CHECKING: 1088 type: Literal["model"] = "model" 1089 else: 1090 type: Literal["model"] 1091 """Specialized resource type 'model'""" 1092 1093 id: Optional[ModelId] = None 1094 """bioimage.io-wide unique resource identifier 1095 assigned by bioimage.io; version **un**specific.""" 1096 1097 authors: NotEmpty[ # pyright: ignore[reportGeneralTypeIssues] # make mandatory 1098 List[Author] 1099 ] 1100 """The authors are the creators of the model RDF and the primary points of contact.""" 1101 1102 documentation: Annotated[ 1103 ImportantFileSource, 1104 Field( 1105 examples=[ 1106 "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md", 1107 "README.md", 1108 ], 1109 ), 1110 ] 1111 """URL or relative path to a markdown file with additional documentation. 1112 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. 1113 The documentation should include a '[#[#]]# Validation' (sub)section 1114 with details on how to quantitatively validate the model on unseen data.""" 1115 1116 inputs: NotEmpty[List[InputTensorDescr]] 1117 """Describes the input tensors expected by this model.""" 1118 1119 license: Annotated[ 1120 Union[LicenseId, str], 1121 warn(LicenseId, "Unknown license id '{value}'."), 1122 Field(examples=["CC0-1.0", "MIT", "BSD-2-Clause"]), 1123 ] 1124 """A [SPDX license identifier](https://spdx.org/licenses/). 1125 We do notsupport custom license beyond the SPDX license list, if you need that please 1126 [open a GitHub issue](https://github.com/bioimage-io/spec-bioimage-io/issues/new/choose 1127 ) to discuss your intentions with the community.""" 1128 1129 name: Annotated[ 1130 str, 1131 MinLen(1), 1132 warn(MinLen(5), "Name shorter than 5 characters.", INFO), 1133 warn(MaxLen(64), "Name longer than 64 characters.", INFO), 1134 ] 1135 """A human-readable name of this model. 
1136 It should be no longer than 64 characters and only contain letter, number, underscore, minus or space characters.""" 1137 1138 outputs: NotEmpty[List[OutputTensorDescr]] 1139 """Describes the output tensors.""" 1140 1141 @field_validator("inputs", "outputs") 1142 @classmethod 1143 def unique_tensor_descr_names( 1144 cls, value: Sequence[Union[InputTensorDescr, OutputTensorDescr]] 1145 ) -> Sequence[Union[InputTensorDescr, OutputTensorDescr]]: 1146 unique_names = {str(v.name) for v in value} 1147 if len(unique_names) != len(value): 1148 raise ValueError("Duplicate tensor descriptor names") 1149 1150 return value 1151 1152 @model_validator(mode="after") 1153 def unique_io_names(self) -> Self: 1154 unique_names = {str(ss.name) for s in (self.inputs, self.outputs) for ss in s} 1155 if len(unique_names) != (len(self.inputs) + len(self.outputs)): 1156 raise ValueError("Duplicate tensor descriptor names across inputs/outputs") 1157 1158 return self 1159 1160 @model_validator(mode="after") 1161 def minimum_shape2valid_output(self) -> Self: 1162 tensors_by_name: Dict[ 1163 TensorName, Union[InputTensorDescr, OutputTensorDescr] 1164 ] = {t.name: t for t in self.inputs + self.outputs} 1165 1166 for out in self.outputs: 1167 if isinstance(out.shape, ImplicitOutputShape): 1168 ndim_ref = len(tensors_by_name[out.shape.reference_tensor].shape) 1169 ndim_out_ref = len( 1170 [scale for scale in out.shape.scale if scale is not None] 1171 ) 1172 if ndim_ref != ndim_out_ref: 1173 expanded_dim_note = ( 1174 " Note that expanded dimensions (`scale`: null) are not" 1175 + f" counted for {out.name}'sdimensionality here." 1176 if None in out.shape.scale 1177 else "" 1178 ) 1179 raise ValueError( 1180 f"Referenced tensor '{out.shape.reference_tensor}' with" 1181 + f" {ndim_ref} dimensions does not match output tensor" 1182 + f" '{out.name}' with" 1183 + f" {ndim_out_ref} dimensions.{expanded_dim_note}" 1184 ) 1185 1186 min_out_shape = self._get_min_shape(out, tensors_by_name) 1187 if out.halo: 1188 halo = out.halo 1189 halo_msg = f" for halo {out.halo}" 1190 else: 1191 halo = [0] * len(min_out_shape) 1192 halo_msg = "" 1193 1194 if any([s - 2 * h < 1 for s, h in zip(min_out_shape, halo)]): 1195 raise ValueError( 1196 f"Minimal shape {min_out_shape} of output {out.name} is too" 1197 + f" small{halo_msg}." 1198 ) 1199 1200 return self 1201 1202 @classmethod 1203 def _get_min_shape( 1204 cls, 1205 t: Union[InputTensorDescr, OutputTensorDescr], 1206 tensors_by_name: Dict[TensorName, Union[InputTensorDescr, OutputTensorDescr]], 1207 ) -> Sequence[int]: 1208 """output with subtracted halo has to result in meaningful output even for the minimal input 1209 see https://github.com/bioimage-io/spec-bioimage-io/issues/392 1210 """ 1211 if isinstance(t.shape, collections.abc.Sequence): 1212 return t.shape 1213 elif isinstance(t.shape, ParameterizedInputShape): 1214 return t.shape.min 1215 elif isinstance(t.shape, ImplicitOutputShape): 1216 pass 1217 else: 1218 assert_never(t.shape) 1219 1220 ref_shape = cls._get_min_shape( 1221 tensors_by_name[t.shape.reference_tensor], tensors_by_name 1222 ) 1223 1224 if None not in t.shape.scale: 1225 scale: Sequence[float, ...] 
= t.shape.scale # type: ignore 1226 else: 1227 expanded_dims = [idx for idx, sc in enumerate(t.shape.scale) if sc is None] 1228 new_ref_shape: List[int] = [] 1229 for idx in range(len(t.shape.scale)): 1230 ref_idx = idx - sum(int(exp < idx) for exp in expanded_dims) 1231 new_ref_shape.append(1 if idx in expanded_dims else ref_shape[ref_idx]) 1232 1233 ref_shape = new_ref_shape 1234 assert len(ref_shape) == len(t.shape.scale) 1235 scale = [0.0 if sc is None else sc for sc in t.shape.scale] 1236 1237 offset = t.shape.offset 1238 assert len(offset) == len(scale) 1239 return [int(rs * s + 2 * off) for rs, s, off in zip(ref_shape, scale, offset)] 1240 1241 @model_validator(mode="after") 1242 def validate_tensor_references_in_inputs(self) -> Self: 1243 for t in self.inputs: 1244 for proc in t.preprocessing: 1245 if "reference_tensor" not in proc.kwargs: 1246 continue 1247 1248 ref_tensor = proc.kwargs["reference_tensor"] 1249 if ref_tensor is not None and str(ref_tensor) not in { 1250 str(t.name) for t in self.inputs 1251 }: 1252 raise ValueError(f"'{ref_tensor}' not found in inputs") 1253 1254 if ref_tensor == t.name: 1255 raise ValueError( 1256 f"invalid self reference for preprocessing of tensor {t.name}" 1257 ) 1258 1259 return self 1260 1261 @model_validator(mode="after") 1262 def validate_tensor_references_in_outputs(self) -> Self: 1263 for t in self.outputs: 1264 for proc in t.postprocessing: 1265 if "reference_tensor" not in proc.kwargs: 1266 continue 1267 ref_tensor = proc.kwargs["reference_tensor"] 1268 if ref_tensor is not None and str(ref_tensor) not in { 1269 str(t.name) for t in self.inputs 1270 }: 1271 raise ValueError(f"{ref_tensor} not found in inputs") 1272 1273 return self 1274 1275 packaged_by: List[Author] = Field(default_factory=list) 1276 """The persons that have packaged and uploaded this model. 1277 Only required if those persons differ from the `authors`.""" 1278 1279 parent: Optional[LinkedModel] = None 1280 """The model from which this model is derived, e.g. by fine-tuning the weights.""" 1281 1282 @field_validator("parent", mode="before") 1283 @classmethod 1284 def ignore_url_parent(cls, parent: Any): 1285 if isinstance(parent, dict): 1286 return None 1287 1288 else: 1289 return parent 1290 1291 run_mode: Optional[RunMode] = None 1292 """Custom run mode for this model: for more complex prediction procedures like test time 1293 data augmentation that currently cannot be expressed in the specification. 1294 No standard run modes are defined yet.""" 1295 1296 sample_inputs: List[ImportantFileSource] = Field(default_factory=list) 1297 """URLs/relative paths to sample inputs to illustrate possible inputs for the model, 1298 for example stored as PNG or TIFF images. 1299 The sample files primarily serve to inform a human user about an example use case""" 1300 1301 sample_outputs: List[ImportantFileSource] = Field(default_factory=list) 1302 """URLs/relative paths to sample outputs corresponding to the `sample_inputs`.""" 1303 1304 test_inputs: NotEmpty[ 1305 List[Annotated[ImportantFileSource, WithSuffix(".npy", case_sensitive=True)]] 1306 ] 1307 """Test input tensors compatible with the `inputs` description for a **single test case**. 1308 This means if your model has more than one input, you should provide one URL/relative path for each input. 1309 Each test input should be a file with an ndarray in 1310 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format). 
1311 The extension must be '.npy'.""" 1312 1313 test_outputs: NotEmpty[ 1314 List[Annotated[ImportantFileSource, WithSuffix(".npy", case_sensitive=True)]] 1315 ] 1316 """Analog to `test_inputs`.""" 1317 1318 timestamp: Datetime 1319 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format 1320 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).""" 1321 1322 training_data: Union[LinkedDataset, DatasetDescr, None] = None 1323 """The dataset used to train this model""" 1324 1325 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)] 1326 """The weights for this model. 1327 Weights can be given for different formats, but should otherwise be equivalent. 1328 The available weight formats determine which consumers can use this model.""" 1329 1330 @model_validator(mode="before") 1331 @classmethod 1332 def _convert_from_older_format( 1333 cls, data: BioimageioYamlContent, / 1334 ) -> BioimageioYamlContent: 1335 convert_from_older_format(data) 1336 return data 1337 1338 def get_input_test_arrays(self) -> List[NDArray[Any]]: 1339 data = [load_array(download(ipt).path) for ipt in self.test_inputs] 1340 assert all(isinstance(d, np.ndarray) for d in data) 1341 return data 1342 1343 def get_output_test_arrays(self) -> List[NDArray[Any]]: 1344 data = [load_array(download(out).path) for out in self.test_outputs] 1345 assert all(isinstance(d, np.ndarray) for d in data) 1346 return data
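The shape arithmetic used by ImplicitOutputShape and by ModelDescr.minimum_shape2valid_output above can be restated as a small standalone sketch. It does not use the spec classes themselves; the example numbers and axis layout are made up for illustration.

```python
from typing import List, Optional, Sequence

def min_output_shape(
    min_input_shape: Sequence[int],
    scale: Sequence[Optional[float]],
    offset: Sequence[float],
) -> List[int]:
    """Minimal output shape per `shape(output) = shape(input) * scale + 2 * offset`.

    A `scale` entry of None marks an expanded (new) dimension whose length is
    `2 * offset`; the reference (input) shape contributes nothing there.
    """
    out: List[int] = []
    ref = list(min_input_shape)
    ref_idx = 0
    for sc, off in zip(scale, offset):
        if sc is None:
            out.append(int(2 * off))  # expanded dimension
        else:
            out.append(int(ref[ref_idx] * sc + 2 * off))
            ref_idx += 1
    return out

# hypothetical example: minimal input (1, 1, 64, 64) for axes "bcyx";
# the output halves y/x and keeps b/c unchanged
print(min_output_shape((1, 1, 64, 64), [1.0, 1.0, 0.5, 0.5], [0.0, 0.0, 0.0, 0.0]))
# -> [1, 1, 32, 32]
```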
119class CallableFromDepencencyNode(Node): 120 _submodule_adapter: ClassVar[TypeAdapter[Identifier]] = TypeAdapter(Identifier) 121 122 module_name: str 123 """The Python module that implements **callable_name**.""" 124 125 @field_validator("module_name", mode="after") 126 def _check_submodules(cls, module_name: str) -> str: 127 for submod in module_name.split("."): 128 _ = cls._submodule_adapter.validate_python(submod) 129 130 return module_name 131 132 callable_name: Identifier 133 """The callable Python identifier implemented in module **module_name**."""
The callable Python identifier implemented in module module_name.
136class CallableFromDepencency(ValidatedStringWithInnerNode[CallableFromDepencencyNode]): 137 _inner_node_class = CallableFromDepencencyNode 138 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 139 Annotated[ 140 str, 141 StringConstraints(strip_whitespace=True, pattern=r"^.+\..+$"), 142 ] 143 ] 144 145 @classmethod 146 def _get_data(cls, valid_string_data: str): 147 *mods, callname = valid_string_data.split(".") 148 return dict(module_name=".".join(mods), callable_name=callname) 149 150 @property 151 def module_name(self): 152 """The Python module that implements **callable_name**.""" 153 return self._inner_node.module_name 154 155 @property 156 def callable_name(self): 157 """The callable Python identifier implemented in module **module_name**.""" 158 return self._inner_node.callable_name
A validated string with further validation and serialization handled via an inner Node.
The pydantic root model used to validate the raw string.
161class CallableFromFileNode(Node): 162 source_file: Annotated[ 163 Union[RelativeFilePath, HttpUrl], 164 Field(union_mode="left_to_right"), 165 include_in_package_serializer, 166 ] 167 """The Python source file that implements **callable_name**.""" 168 callable_name: Identifier 169 """The callable Python identifier implemented in **source_file**."""
The Python source file that implements callable_name.
172class CallableFromFile(ValidatedStringWithInnerNode[CallableFromFileNode]): 173 _inner_node_class = CallableFromFileNode 174 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 175 Annotated[ 176 str, 177 StringConstraints(strip_whitespace=True, pattern=r"^.+:.+$"), 178 ] 179 ] 180 181 @classmethod 182 def _get_data(cls, valid_string_data: str): 183 *file_parts, callname = valid_string_data.split(":") 184 return dict(source_file=":".join(file_parts), callable_name=callname) 185 186 @property 187 def source_file(self): 188 """The Python source file that implements **callable_name**.""" 189 return self._inner_node.source_file 190 191 @property 192 def callable_name(self): 193 """The callable Python identifier implemented in **source_file**.""" 194 return self._inner_node.callable_name
A validated string with further validation and serialization handled via an inner Node.
The pydantic root model used to validate the raw string.
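How the two CustomCallable string forms decompose can be shown with a small standalone sketch that mirrors the _get_data methods above (the example strings come from the field examples in the source; the helper names are hypothetical, not part of the library):

```python
def split_callable_from_dependency(value: str):
    # "my_module.submodule.get_my_model" -> module path + callable name
    *mods, callable_name = value.split(".")
    return {"module_name": ".".join(mods), "callable_name": callable_name}

def split_callable_from_file(value: str):
    # "my_function.py:MyNetworkClass" -> source file + callable name
    *file_parts, callable_name = value.split(":")
    return {"source_file": ":".join(file_parts), "callable_name": callable_name}

assert split_callable_from_dependency("my_module.submodule.get_my_model") == {
    "module_name": "my_module.submodule",
    "callable_name": "get_my_model",
}
assert split_callable_from_file("my_function.py:MyNetworkClass") == {
    "source_file": "my_function.py",
    "callable_name": "MyNetworkClass",
}
```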
202class DependenciesNode(Node): 203 manager: Annotated[NotEmpty[str], Field(examples=["conda", "maven", "pip"])] 204 """Dependency manager""" 205 206 file: ImportantFileSource 207 """Dependency file"""
Dependency manager
Dependency file
210class Dependencies(ValidatedStringWithInnerNode[DependenciesNode]): 211 _inner_node_class = DependenciesNode 212 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 213 Annotated[ 214 str, 215 StringConstraints(strip_whitespace=True, pattern=r"^.+:.+$"), 216 ] 217 ] 218 219 @classmethod 220 def _get_data(cls, valid_string_data: str): 221 manager, *file_parts = valid_string_data.split(":") 222 return dict(manager=manager, file=":".join(file_parts)) 223 224 @property 225 def manager(self): 226 """Dependency manager""" 227 return self._inner_node.manager 228 229 @property 230 def file(self): 231 """Dependency file""" 232 return self._inner_node.file
A validated string with further validation and serialization handled via an inner Node.
The pydantic root model used to validate the raw string.
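The same convention applies to Dependencies strings: the first ":"-separated token is the manager, the remainder is the file path. A minimal sketch mirroring _get_data above (helper name hypothetical):

```python
def split_dependencies(value: str):
    # "conda:environment.yaml" -> manager "conda", file "environment.yaml"
    manager, *file_parts = value.split(":")
    return {"manager": manager, "file": ":".join(file_parts)}

assert split_dependencies("conda:environment.yaml") == {
    "manager": "conda",
    "file": "environment.yaml",
}
```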
245class WeightsDescr(Node): 246 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None 247 onnx: Optional[OnnxWeightsDescr] = None 248 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None 249 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None 250 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = ( 251 None 252 ) 253 torchscript: Optional[TorchscriptWeightsDescr] = None 254 255 @model_validator(mode="after") 256 def check_one_entry(self) -> Self: 257 if all( 258 entry is None 259 for entry in [ 260 self.keras_hdf5, 261 self.onnx, 262 self.pytorch_state_dict, 263 self.tensorflow_js, 264 self.tensorflow_saved_model_bundle, 265 self.torchscript, 266 ] 267 ): 268 raise ValueError("Missing weights entry") 269 270 return self 271 272 def __getitem__( 273 self, 274 key: WeightsFormat, 275 ): 276 if key == "keras_hdf5": 277 ret = self.keras_hdf5 278 elif key == "onnx": 279 ret = self.onnx 280 elif key == "pytorch_state_dict": 281 ret = self.pytorch_state_dict 282 elif key == "tensorflow_js": 283 ret = self.tensorflow_js 284 elif key == "tensorflow_saved_model_bundle": 285 ret = self.tensorflow_saved_model_bundle 286 elif key == "torchscript": 287 ret = self.torchscript 288 else: 289 raise KeyError(key) 290 291 if ret is None: 292 raise KeyError(key) 293 294 return ret 295 296 @property 297 def available_formats(self): 298 return { 299 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 300 **({} if self.onnx is None else {"onnx": self.onnx}), 301 **( 302 {} 303 if self.pytorch_state_dict is None 304 else {"pytorch_state_dict": self.pytorch_state_dict} 305 ), 306 **( 307 {} 308 if self.tensorflow_js is None 309 else {"tensorflow_js": self.tensorflow_js} 310 ), 311 **( 312 {} 313 if self.tensorflow_saved_model_bundle is None 314 else { 315 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 316 } 317 ), 318 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 319 } 320 321 @property 322 def missing_formats(self): 323 return { 324 wf for wf in get_args(WeightsFormat) if wf not in self.available_formats 325 }
255 @model_validator(mode="after") 256 def check_one_entry(self) -> Self: 257 if all( 258 entry is None 259 for entry in [ 260 self.keras_hdf5, 261 self.onnx, 262 self.pytorch_state_dict, 263 self.tensorflow_js, 264 self.tensorflow_saved_model_bundle, 265 self.torchscript, 266 ] 267 ): 268 raise ValueError("Missing weights entry") 269 270 return self
296 @property 297 def available_formats(self): 298 return { 299 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 300 **({} if self.onnx is None else {"onnx": self.onnx}), 301 **( 302 {} 303 if self.pytorch_state_dict is None 304 else {"pytorch_state_dict": self.pytorch_state_dict} 305 ), 306 **( 307 {} 308 if self.tensorflow_js is None 309 else {"tensorflow_js": self.tensorflow_js} 310 ), 311 **( 312 {} 313 if self.tensorflow_saved_model_bundle is None 314 else { 315 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 316 } 317 ), 318 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 319 }
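Assuming an already validated WeightsDescr instance, the accessors above let a consumer pick a preferred weights format. A hedged usage sketch (function and variable names are hypothetical, not part of the library):

```python
from typing import Sequence

def pick_weights_entry(weights, priority: Sequence[str]):
    """Return (format, entry) for the first format in `priority` that is present.

    `weights` is assumed to be a validated WeightsDescr; `priority` contains
    WeightsFormat names such as "onnx" or "torchscript".
    """
    available = weights.available_formats
    for wf in priority:
        if wf in available:
            return wf, weights[wf]  # __getitem__ raises KeyError for absent entries
    raise ValueError(
        f"none of {list(priority)} available; missing formats: {weights.missing_formats}"
    )
```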
328class WeightsEntryDescrBase(FileDescr): 329 type: ClassVar[WeightsFormat] 330 weights_format_name: ClassVar[str] # human readable 331 332 source: ImportantFileSource 333 """The weights file.""" 334 335 attachments: Annotated[ 336 Union[AttachmentsDescr, None], 337 warn(None, "Weights entry depends on additional attachments.", ALERT), 338 ] = None 339 """Attachments that are specific to this weights entry.""" 340 341 authors: Union[List[Author], None] = None 342 """Authors 343 Either the person(s) that have trained this model resulting in the original weights file. 344 (If this is the initial weights entry, i.e. it does not have a `parent`) 345 Or the person(s) who have converted the weights to this weights format. 346 (If this is a child weight, i.e. it has a `parent` field) 347 """ 348 349 dependencies: Annotated[ 350 Optional[Dependencies], 351 warn( 352 None, 353 "Custom dependencies ({value}) specified. Avoid this whenever possible " 354 + "to allow execution in a wider range of software environments.", 355 ), 356 Field( 357 examples=[ 358 "conda:environment.yaml", 359 "maven:./pom.xml", 360 "pip:./requirements.txt", 361 ] 362 ), 363 ] = None 364 """Dependency manager and dependency file, specified as `<dependency manager>:<relative file path>`.""" 365 366 parent: Annotated[ 367 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"]) 368 ] = None 369 """The source weights these weights were converted from. 370 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`, 371 The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights. 372 All weight entries except one (the initial set of weights resulting from training the model), 373 need to have this field.""" 374 375 @model_validator(mode="after") 376 def check_parent_is_not_self(self) -> Self: 377 if self.type == self.parent: 378 raise ValueError("Weights entry can't be it's own parent.") 379 380 return self
The weights file.
Attachments that are specific to this weights entry.
Dependency manager and dependency file, specified as <dependency manager>:<relative file path>.

The source weights these weights were converted from.
For example, if a model's weights were converted from the pytorch_state_dict format to torchscript, the pytorch_state_dict weights entry has no parent and is the parent of the torchscript weights.
All weight entries except one (the initial set of weights resulting from training the model) need to have this field.
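Because every entry except the originally trained one carries a parent, the initial weights entry can be located from a validated WeightsDescr as sketched below (hypothetical helper, not library code):

```python
def find_initial_weights(weights):
    """Return (format, entry) for the single weights entry without a `parent`,
    i.e. the weights that resulted directly from training.
    Assumes `weights` is a validated WeightsDescr."""
    initial = {
        wf: entry
        for wf, entry in weights.available_formats.items()
        if entry.parent is None
    }
    if len(initial) != 1:
        raise ValueError(f"expected exactly one parent-less entry, got {sorted(initial)}")
    return next(iter(initial.items()))
```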
383class KerasHdf5WeightsDescr(WeightsEntryDescrBase): 384 type = "keras_hdf5" 385 weights_format_name: ClassVar[str] = "Keras HDF5" 386 tensorflow_version: Optional[Version] = None 387 """TensorFlow version used to create these weights""" 388 389 @field_validator("tensorflow_version", mode="after") 390 @classmethod 391 def _tfv(cls, value: Any): 392 if value is None: 393 issue_warning( 394 "missing. Please specify the TensorFlow version" 395 + " these weights were created with.", 396 value=value, 397 severity=ALERT, 398 field="tensorflow_version", 399 ) 400 return value
TensorFlow version used to create these weights
403class OnnxWeightsDescr(WeightsEntryDescrBase): 404 type = "onnx" 405 weights_format_name: ClassVar[str] = "ONNX" 406 opset_version: Optional[Annotated[int, Ge(7)]] = None 407 """ONNX opset version""" 408 409 @field_validator("opset_version", mode="after") 410 @classmethod 411 def _ov(cls, value: Any): 412 if value is None: 413 issue_warning( 414 "Missing ONNX opset version (aka ONNX opset number). " 415 + "Please specify the ONNX opset version these weights were created" 416 + " with.", 417 value=value, 418 severity=ALERT, 419 field="opset_version", 420 ) 421 return value
424class PytorchStateDictWeightsDescr(WeightsEntryDescrBase): 425 type = "pytorch_state_dict" 426 weights_format_name: ClassVar[str] = "Pytorch State Dict" 427 architecture: CustomCallable = Field( 428 examples=["my_function.py:MyNetworkClass", "my_module.submodule.get_my_model"] 429 ) 430 """callable returning a torch.nn.Module instance. 431 Local implementation: `<relative path to file>:<identifier of implementation within the file>`. 432 Implementation in a dependency: `<dependency-package>.<[dependency-module]>.<identifier>`.""" 433 434 architecture_sha256: Annotated[ 435 Optional[Sha256], 436 Field( 437 description=( 438 "The SHA256 of the architecture source file, if the architecture is not" 439 " defined in a module listed in `dependencies`\n" 440 ) 441 + SHA256_HINT, 442 ), 443 ] = None 444 """The SHA256 of the architecture source file, 445 if the architecture is not defined in a module listed in `dependencies`""" 446 447 @model_validator(mode="after") 448 def check_architecture_sha256(self) -> Self: 449 if isinstance(self.architecture, CallableFromFile): 450 if self.architecture_sha256 is None: 451 raise ValueError( 452 "Missing required `architecture_sha256` for `architecture` with" 453 + " source file." 454 ) 455 elif self.architecture_sha256 is not None: 456 raise ValueError( 457 "Got `architecture_sha256` for architecture that does not have a source" 458 + " file." 459 ) 460 461 return self 462 463 kwargs: Dict[str, Any] = Field(default_factory=dict) 464 """key word arguments for the `architecture` callable""" 465 466 pytorch_version: Optional[Version] = None 467 """Version of the PyTorch library used. 468 If `depencencies` is specified it should include pytorch and the verison has to match. 469 (`dependencies` overrules `pytorch_version`)""" 470 471 @field_validator("pytorch_version", mode="after") 472 @classmethod 473 def _ptv(cls, value: Any): 474 if value is None: 475 issue_warning( 476 "missing. Please specify the PyTorch version these" 477 + " PyTorch state dict weights were created with.", 478 value=value, 479 severity=ALERT, 480 field="pytorch_version", 481 ) 482 return value
callable returning a torch.nn.Module instance.
Local implementation: `<relative path to file>:<identifier of implementation within the file>`.
Implementation in a dependency: `<dependency-package>.<[dependency-module]>.<identifier>`.
The SHA256 of the architecture source file, if the architecture is not defined in a module listed in `dependencies`.
Version of the PyTorch library used.
If `dependencies` is specified it should include pytorch and the version has to match.
(`dependencies` overrules `pytorch_version`)
485class TorchscriptWeightsDescr(WeightsEntryDescrBase): 486 type = "torchscript" 487 weights_format_name: ClassVar[str] = "TorchScript" 488 pytorch_version: Optional[Version] = None 489 """Version of the PyTorch library used.""" 490 491 @field_validator("pytorch_version", mode="after") 492 @classmethod 493 def _ptv(cls, value: Any): 494 if value is None: 495 issue_warning( 496 "missing. Please specify the PyTorch version these" 497 + " Torchscript weights were created with.", 498 value=value, 499 severity=ALERT, 500 field="pytorch_version", 501 ) 502 return value
Version of the PyTorch library used.
505class TensorflowJsWeightsDescr(WeightsEntryDescrBase): 506 type = "tensorflow_js" 507 weights_format_name: ClassVar[str] = "Tensorflow.js" 508 tensorflow_version: Optional[Version] = None 509 """Version of the TensorFlow library used.""" 510 511 @field_validator("tensorflow_version", mode="after") 512 @classmethod 513 def _tfv(cls, value: Any): 514 if value is None: 515 issue_warning( 516 "missing. Please specify the TensorFlow version" 517 + " these TensorflowJs weights were created with.", 518 value=value, 519 severity=ALERT, 520 field="tensorflow_version", 521 ) 522 return value 523 524 source: ImportantFileSource 525 """The multi-file weights. 526 All required files/folders should be a zip archive."""
Version of the TensorFlow library used.
The multi-file weights. All required files/folders should be a zip archive.
529class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase): 530 type = "tensorflow_saved_model_bundle" 531 weights_format_name: ClassVar[str] = "Tensorflow Saved Model" 532 tensorflow_version: Optional[Version] = None 533 """Version of the TensorFlow library used.""" 534 535 @field_validator("tensorflow_version", mode="after") 536 @classmethod 537 def _tfv(cls, value: Any): 538 if value is None: 539 issue_warning( 540 "missing. Please specify the TensorFlow version" 541 + " these Tensorflow saved model bundle weights were created with.", 542 value=value, 543 severity=ALERT, 544 field="tensorflow_version", 545 ) 546 return value
Version of the TensorFlow library used.
549class ParameterizedInputShape(Node): 550 """A sequence of valid shapes given by `shape_k = min + k * step for k in {0, 1, ...}`.""" 551 552 min: NotEmpty[List[int]] 553 """The minimum input shape""" 554 555 step: NotEmpty[List[int]] 556 """The minimum shape change""" 557 558 def __len__(self) -> int: 559 return len(self.min) 560 561 @model_validator(mode="after") 562 def matching_lengths(self) -> Self: 563 if len(self.min) != len(self.step): 564 raise ValueError("`min` and `step` required to have the same length") 565 566 return self
A sequence of valid shapes given by `shape_k = min + k * step for k in {0, 1, ...}`.
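A minimal sketch (with made-up `min`/`step` values) of how the first few valid shapes are enumerated:

```python
# Hypothetical values for a b,y,x,c input: batch step is 0, spatial steps are 32.
min_shape = [1, 64, 64, 1]
step = [0, 32, 32, 0]

for k in range(3):
    shape_k = [m + k * s for m, s in zip(min_shape, step)]
    print(shape_k)
# [1, 64, 64, 1] -> [1, 96, 96, 1] -> [1, 128, 128, 1]
```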
569class ImplicitOutputShape(Node): 570 """Output tensor shape depending on an input tensor shape. 571 `shape(output_tensor) = shape(input_tensor) * scale + 2 * offset`""" 572 573 reference_tensor: TensorName 574 """Name of the reference tensor.""" 575 576 scale: NotEmpty[List[Optional[float]]] 577 """output_pix/input_pix for each dimension. 578 'null' values indicate new dimensions, whose length is defined by 2*`offset`""" 579 580 offset: NotEmpty[List[Union[int, Annotated[float, MultipleOf(0.5)]]]] 581 """Position of origin wrt to input.""" 582 583 def __len__(self) -> int: 584 return len(self.scale) 585 586 @model_validator(mode="after") 587 def matching_lengths(self) -> Self: 588 if len(self.scale) != len(self.offset): 589 raise ValueError( 590 f"scale {self.scale} has to have same length as offset {self.offset}!" 591 ) 592 # if we have an expanded dimension, make sure that it's offet is not zero 593 for sc, off in zip(self.scale, self.offset): 594 if sc is None and not off: 595 raise ValueError("`offset` must not be zero if `scale` is none/zero") 596 597 return self
Output tensor shape depending on an input tensor shape.
`shape(output_tensor) = shape(input_tensor) * scale + 2 * offset`
output_pix/input_pix for each dimension.
'null' values indicate new dimensions, whose length is defined by 2*`offset`.
Position of origin w.r.t. input.
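A minimal sketch, assuming hypothetical `scale`/`offset` values, of how an output shape follows from a reference input shape:

```python
# shape(output) = shape(input) * scale + 2 * offset;
# a `None` scale marks a new dimension of length 2 * offset.
input_shape = [1, 256, 256, 1]   # assumed reference tensor shape
scale = [1.0, 0.5, 0.5, None]
offset = [0, 0, 0, 8]

output_shape = [
    int(2 * off) if sc is None else int(size * sc + 2 * off)
    for size, sc, off in zip(input_shape, scale, offset)
]
print(output_shape)  # [1, 128, 128, 16]
```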
600class TensorDescrBase(Node): 601 name: TensorName 602 """Tensor name. No duplicates are allowed.""" 603 604 description: str = "" 605 606 axes: AxesStr 607 """Axes identifying characters. Same length and order as the axes in `shape`. 608 | axis | description | 609 | --- | --- | 610 | b | batch (groups multiple samples) | 611 | i | instance/index/element | 612 | t | time | 613 | c | channel | 614 | z | spatial dimension z | 615 | y | spatial dimension y | 616 | x | spatial dimension x | 617 """ 618 619 data_range: Optional[ 620 Tuple[Annotated[float, AllowInfNan(True)], Annotated[float, AllowInfNan(True)]] 621 ] = None 622 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor. 623 If not specified, the full data range that can be expressed in `data_type` is allowed."""
Axes identifying characters. Same length and order as the axes in `shape`.

| axis | description |
| --- | --- |
| b | batch (groups multiple samples) |
| i | instance/index/element |
| t | time |
| c | channel |
| z | spatial dimension z |
| y | spatial dimension y |
| x | spatial dimension x |
626class ProcessingKwargs(KwargsNode): 627 """base class for pre-/postprocessing key word arguments"""
Base class for pre-/postprocessing keyword arguments.
Processing base class.
634class BinarizeKwargs(ProcessingKwargs): 635 """key word arguments for `BinarizeDescr`""" 636 637 threshold: float 638 """The fixed threshold"""
Keyword arguments for `BinarizeDescr`.
641class BinarizeDescr(ProcessingDescrBase): 642 """BinarizeDescr the tensor with a fixed `BinarizeKwargs.threshold`. 643 Values above the threshold will be set to one, values below the threshold to zero. 644 """ 645 646 implemented_name: ClassVar[Literal["binarize"]] = "binarize" 647 if TYPE_CHECKING: 648 name: Literal["binarize"] = "binarize" 649 else: 650 name: Literal["binarize"] 651 652 kwargs: BinarizeKwargs
Binarize the tensor with a fixed `BinarizeKwargs.threshold`.
Values above the threshold will be set to one, values below the threshold to zero.
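A minimal numpy sketch of the operation (the threshold value is made up; behaviour for values exactly at the threshold is not specified here):

```python
import numpy as np

tensor = np.array([0.1, 0.4, 0.6, 0.9], dtype="float32")
threshold = 0.5  # assumed BinarizeKwargs.threshold
binarized = (tensor > threshold).astype("float32")
print(binarized)  # [0. 0. 1. 1.]
```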
655class ClipKwargs(ProcessingKwargs): 656 """key word arguments for `ClipDescr`""" 657 658 min: float 659 """minimum value for clipping""" 660 max: float 661 """maximum value for clipping"""
Keyword arguments for `ClipDescr`.
664class ClipDescr(ProcessingDescrBase): 665 """Clip tensor values to a range. 666 667 Set tensor values below `ClipKwargs.min` to `ClipKwargs.min` 668 and above `ClipKwargs.max` to `ClipKwargs.max`. 669 """ 670 671 implemented_name: ClassVar[Literal["clip"]] = "clip" 672 if TYPE_CHECKING: 673 name: Literal["clip"] = "clip" 674 else: 675 name: Literal["clip"] 676 677 kwargs: ClipKwargs
Clip tensor values to a range.
Set tensor values below `ClipKwargs.min` to `ClipKwargs.min` and above `ClipKwargs.max` to `ClipKwargs.max`.
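A minimal numpy sketch with assumed `min`/`max` bounds:

```python
import numpy as np

tensor = np.array([-1.0, 0.2, 0.8, 3.0], dtype="float32")
clipped = np.clip(tensor, 0.0, 1.0)  # ClipKwargs.min=0.0, ClipKwargs.max=1.0 (assumed)
print(clipped)  # [0.  0.2 0.8 1. ]
```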
680class ScaleLinearKwargs(ProcessingKwargs): 681 """key word arguments for `ScaleLinearDescr`""" 682 683 axes: Annotated[Optional[AxesInCZYX], Field(examples=["xy"])] = None 684 """The subset of axes to scale jointly. 685 For example xy to scale the two image axes for 2d data jointly.""" 686 687 gain: Union[float, List[float]] = 1.0 688 """multiplicative factor""" 689 690 offset: Union[float, List[float]] = 0.0 691 """additive term""" 692 693 @model_validator(mode="after") 694 def either_gain_or_offset(self) -> Self: 695 if ( 696 self.gain == 1.0 697 or isinstance(self.gain, list) 698 and all(g == 1.0 for g in self.gain) 699 ) and ( 700 self.offset == 0.0 701 or isinstance(self.offset, list) 702 and all(off == 0.0 for off in self.offset) 703 ): 704 raise ValueError( 705 "Redunt linear scaling not allowd. Set `gain` != 1.0 and/or `offset` !=" 706 + " 0.0." 707 ) 708 709 return self
Keyword arguments for `ScaleLinearDescr`.
The subset of axes to scale jointly. For example `xy` to scale the two image axes for 2d data jointly.
712class ScaleLinearDescr(ProcessingDescrBase): 713 """Fixed linear scaling.""" 714 715 implemented_name: ClassVar[Literal["scale_linear"]] = "scale_linear" 716 if TYPE_CHECKING: 717 name: Literal["scale_linear"] = "scale_linear" 718 else: 719 name: Literal["scale_linear"] 720 721 kwargs: ScaleLinearKwargs
Fixed linear scaling.
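A minimal numpy sketch of the intended `out = gain * tensor + offset`, here assuming per-channel `gain`/`offset` applied jointly over `axes: xy`:

```python
import numpy as np

tensor = np.random.rand(2, 32, 32).astype("float32")  # assumed c, y, x layout
gain = np.array([2.0, 0.5], dtype="float32")           # one value per channel
offset = np.array([0.0, -1.0], dtype="float32")
out = gain[:, None, None] * tensor + offset[:, None, None]
```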
724class SigmoidDescr(ProcessingDescrBase): 725 """The logistic sigmoid funciton, a.k.a. expit function.""" 726 727 implemented_name: ClassVar[Literal["sigmoid"]] = "sigmoid" 728 if TYPE_CHECKING: 729 name: Literal["sigmoid"] = "sigmoid" 730 else: 731 name: Literal["sigmoid"] 732 733 @property 734 def kwargs(self) -> ProcessingKwargs: 735 """empty kwargs""" 736 return ProcessingKwargs()
The logistic sigmoid function, a.k.a. expit function.
empty kwargs
739class ZeroMeanUnitVarianceKwargs(ProcessingKwargs): 740 """key word arguments for `ZeroMeanUnitVarianceDescr`""" 741 742 mode: Literal["fixed", "per_dataset", "per_sample"] = "fixed" 743 """Mode for computing mean and variance. 744 | mode | description | 745 | ----------- | ------------------------------------ | 746 | fixed | Fixed values for mean and variance | 747 | per_dataset | Compute for the entire dataset | 748 | per_sample | Compute for each sample individually | 749 """ 750 axes: Annotated[AxesInCZYX, Field(examples=["xy"])] 751 """The subset of axes to normalize jointly. 752 For example `xy` to normalize the two image axes for 2d data jointly.""" 753 754 mean: Annotated[ 755 Union[float, NotEmpty[List[float]], None], Field(examples=[(1.1, 2.2, 3.3)]) 756 ] = None 757 """The mean value(s) to use for `mode: fixed`. 758 For example `[1.1, 2.2, 3.3]` in the case of a 3 channel image with `axes: xy`.""" 759 # todo: check if means match input axes (for mode 'fixed') 760 761 std: Annotated[ 762 Union[float, NotEmpty[List[float]], None], Field(examples=[(0.1, 0.2, 0.3)]) 763 ] = None 764 """The standard deviation values to use for `mode: fixed`. Analogous to mean.""" 765 766 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 767 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`.""" 768 769 @model_validator(mode="after") 770 def mean_and_std_match_mode(self) -> Self: 771 if self.mode == "fixed" and (self.mean is None or self.std is None): 772 raise ValueError("`mean` and `std` are required for `mode: fixed`.") 773 elif self.mode != "fixed" and (self.mean is not None or self.std is not None): 774 raise ValueError(f"`mean` and `std` not allowed for `mode: {self.mode}`") 775 776 return self
Keyword arguments for `ZeroMeanUnitVarianceDescr`.
Mode for computing mean and variance.

| mode | description |
| --- | --- |
| fixed | Fixed values for mean and variance |
| per_dataset | Compute for the entire dataset |
| per_sample | Compute for each sample individually |

The subset of axes to normalize jointly. For example `xy` to normalize the two image axes for 2d data jointly.
The mean value(s) to use for `mode: fixed`. For example `[1.1, 2.2, 3.3]` in the case of a 3 channel image with `axes: xy`.
The standard deviation values to use for `mode: fixed`. Analogous to mean.
Epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`.
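A minimal numpy sketch of `mode: per_sample`, assuming a `c, y, x` tensor layout and `axes: xy` (statistics computed per channel over the two spatial axes):

```python
import numpy as np

tensor = np.random.rand(3, 64, 64).astype("float32")  # assumed c, y, x layout
eps = 1e-6
mean = tensor.mean(axis=(1, 2), keepdims=True)
std = tensor.std(axis=(1, 2), keepdims=True)
out = (tensor - mean) / (std + eps)
```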
779class ZeroMeanUnitVarianceDescr(ProcessingDescrBase): 780 """Subtract mean and divide by variance.""" 781 782 implemented_name: ClassVar[Literal["zero_mean_unit_variance"]] = ( 783 "zero_mean_unit_variance" 784 ) 785 if TYPE_CHECKING: 786 name: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance" 787 else: 788 name: Literal["zero_mean_unit_variance"] 789 790 kwargs: ZeroMeanUnitVarianceKwargs
Subtract mean and divide by variance.
793class ScaleRangeKwargs(ProcessingKwargs): 794 """key word arguments for `ScaleRangeDescr` 795 796 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default) 797 this processing step normalizes data to the [0, 1] intervall. 798 For other percentiles the normalized values will partially be outside the [0, 1] 799 intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the 800 normalized values to a range. 801 """ 802 803 mode: Literal["per_dataset", "per_sample"] 804 """Mode for computing percentiles. 805 | mode | description | 806 | ----------- | ------------------------------------ | 807 | per_dataset | compute for the entire dataset | 808 | per_sample | compute for each sample individually | 809 """ 810 axes: Annotated[AxesInCZYX, Field(examples=["xy"])] 811 """The subset of axes to normalize jointly. 812 For example xy to normalize the two image axes for 2d data jointly.""" 813 814 min_percentile: Annotated[Union[int, float], Interval(ge=0, lt=100)] = 0.0 815 """The lower percentile used to determine the value to align with zero.""" 816 817 max_percentile: Annotated[Union[int, float], Interval(gt=1, le=100)] = 100.0 818 """The upper percentile used to determine the value to align with one. 819 Has to be bigger than `min_percentile`. 820 The range is 1 to 100 instead of 0 to 100 to avoid mistakenly 821 accepting percentiles specified in the range 0.0 to 1.0.""" 822 823 @model_validator(mode="after") 824 def min_smaller_max(self, info: ValidationInfo) -> Self: 825 if self.min_percentile >= self.max_percentile: 826 raise ValueError( 827 f"min_percentile {self.min_percentile} >= max_percentile" 828 + f" {self.max_percentile}" 829 ) 830 831 return self 832 833 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 834 """Epsilon for numeric stability. 835 `out = (tensor - v_lower) / (v_upper - v_lower + eps)`; 836 with `v_lower,v_upper` values at the respective percentiles.""" 837 838 reference_tensor: Optional[TensorName] = None 839 """Tensor name to compute the percentiles from. Default: The tensor itself. 840 For any tensor in `inputs` only input tensor references are allowed. 841 For a tensor in `outputs` only input tensor refereences are allowed if `mode: per_dataset`"""
Keyword arguments for `ScaleRangeDescr`.
For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default) this processing step normalizes data to the [0, 1] interval. For other percentiles the normalized values will partially be outside the [0, 1] interval. Use `ScaleRange` followed by `ClipDescr` if you want to limit the normalized values to a range.
Mode for computing percentiles.

| mode | description |
| --- | --- |
| per_dataset | compute for the entire dataset |
| per_sample | compute for each sample individually |

The subset of axes to normalize jointly. For example `xy` to normalize the two image axes for 2d data jointly.
The lower percentile used to determine the value to align with zero.
The upper percentile used to determine the value to align with one. Has to be bigger than `min_percentile`. The range is 1 to 100 instead of 0 to 100 to avoid mistakenly accepting percentiles specified in the range 0.0 to 1.0.
Epsilon for numeric stability. `out = (tensor - v_lower) / (v_upper - v_lower + eps)`; with `v_lower,v_upper` values at the respective percentiles.
Tensor name to compute the percentiles from. Default: The tensor itself. For any tensor in `inputs` only input tensor references are allowed. For a tensor in `outputs` only input tensor references are allowed if `mode: per_dataset`.
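A minimal numpy sketch of `mode: per_sample` with assumed percentiles and a `c, y, x` tensor layout (`axes: xy`):

```python
import numpy as np

tensor = np.random.rand(3, 64, 64).astype("float32")  # assumed c, y, x layout
eps = 1e-6
v_lower = np.percentile(tensor, 1.0, axis=(1, 2), keepdims=True)   # min_percentile (assumed)
v_upper = np.percentile(tensor, 99.0, axis=(1, 2), keepdims=True)  # max_percentile (assumed)
out = (tensor - v_lower) / (v_upper - v_lower + eps)
# follow with a clip step if the result must stay within [0, 1]
```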
844class ScaleRangeDescr(ProcessingDescrBase): 845 """Scale with percentiles.""" 846 847 implemented_name: ClassVar[Literal["scale_range"]] = "scale_range" 848 if TYPE_CHECKING: 849 name: Literal["scale_range"] = "scale_range" 850 else: 851 name: Literal["scale_range"] 852 853 kwargs: ScaleRangeKwargs
Scale with percentiles.
856class ScaleMeanVarianceKwargs(ProcessingKwargs): 857 """key word arguments for `ScaleMeanVarianceDescr`""" 858 859 mode: Literal["per_dataset", "per_sample"] 860 """Mode for computing mean and variance. 861 | mode | description | 862 | ----------- | ------------------------------------ | 863 | per_dataset | Compute for the entire dataset | 864 | per_sample | Compute for each sample individually | 865 """ 866 867 reference_tensor: TensorName 868 """Name of tensor to match.""" 869 870 axes: Annotated[Optional[AxesInCZYX], Field(examples=["xy"])] = None 871 """The subset of axes to scale jointly. 872 For example xy to normalize the two image axes for 2d data jointly. 873 Default: scale all non-batch axes jointly.""" 874 875 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 876 """Epsilon for numeric stability: 877 "`out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean."""
Keyword arguments for `ScaleMeanVarianceDescr`.
Mode for computing mean and variance.

| mode | description |
| --- | --- |
| per_dataset | Compute for the entire dataset |
| per_sample | Compute for each sample individually |

The subset of axes to scale jointly. For example `xy` to normalize the two image axes for 2d data jointly. Default: scale all non-batch axes jointly.
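A minimal numpy sketch of matching a tensor's statistics to a hypothetical reference tensor (statistics computed over all axes here for brevity):

```python
import numpy as np

tensor = np.random.rand(64, 64).astype("float32")
ref = 5.0 * np.random.rand(64, 64).astype("float32") + 2.0  # assumed reference tensor
eps = 1e-6
out = (tensor - tensor.mean()) / (tensor.std() + eps) * (ref.std() + eps) + ref.mean()
```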
880class ScaleMeanVarianceDescr(ProcessingDescrBase): 881 """Scale the tensor s.t. its mean and variance match a reference tensor.""" 882 883 implemented_name: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance" 884 if TYPE_CHECKING: 885 name: Literal["scale_mean_variance"] = "scale_mean_variance" 886 else: 887 name: Literal["scale_mean_variance"] 888 889 kwargs: ScaleMeanVarianceKwargs
Scale the tensor s.t. its mean and variance match a reference tensor.
917class InputTensorDescr(TensorDescrBase): 918 data_type: Literal["float32", "uint8", "uint16"] 919 """For now an input tensor is expected to be given as `float32`. 920 The data flow in bioimage.io models is explained 921 [in this diagram.](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit).""" 922 923 shape: Annotated[ 924 Union[Sequence[int], ParameterizedInputShape], 925 Field( 926 examples=[(1, 512, 512, 1), dict(min=(1, 64, 64, 1), step=(0, 32, 32, 0))] 927 ), 928 ] 929 """Specification of input tensor shape.""" 930 931 preprocessing: List[PreprocessingDescr] = Field(default_factory=list) 932 """Description of how this input should be preprocessed.""" 933 934 @model_validator(mode="after") 935 def zero_batch_step_and_one_batch_size(self) -> Self: 936 bidx = self.axes.find("b") 937 if bidx == -1: 938 return self 939 940 if isinstance(self.shape, ParameterizedInputShape): 941 step = self.shape.step 942 shape = self.shape.min 943 if step[bidx] != 0: 944 raise ValueError( 945 "Input shape step has to be zero in the batch dimension (the batch" 946 + " dimension can always be increased, but `step` should specify how" 947 + " to increase the minimal shape to find the largest single batch" 948 + " shape)" 949 ) 950 else: 951 shape = self.shape 952 953 if shape[bidx] != 1: 954 raise ValueError("Input shape has to be 1 in the batch dimension b.") 955 956 return self 957 958 @model_validator(mode="after") 959 def validate_preprocessing_kwargs(self) -> Self: 960 for p in self.preprocessing: 961 kwargs_axes = p.kwargs.get("axes") 962 if isinstance(kwargs_axes, str) and any( 963 a not in self.axes for a in kwargs_axes 964 ): 965 raise ValueError("`kwargs.axes` needs to be subset of `axes`") 966 967 return self
For now an input tensor is expected to be given as `float32`.
The data flow in bioimage.io models is explained [in this diagram](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit).
Specification of input tensor shape.
Description of how this input should be preprocessed.
970class OutputTensorDescr(TensorDescrBase): 971 data_type: Literal[ 972 "float32", 973 "float64", 974 "uint8", 975 "int8", 976 "uint16", 977 "int16", 978 "uint32", 979 "int32", 980 "uint64", 981 "int64", 982 "bool", 983 ] 984 """Data type. 985 The data flow in bioimage.io models is explained 986 [in this diagram.](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit).""" 987 988 shape: Union[Sequence[int], ImplicitOutputShape] 989 """Output tensor shape.""" 990 991 halo: Optional[Sequence[int]] = None 992 """The `halo` that should be cropped from the output tensor to avoid boundary effects. 993 The `halo` is to be cropped from both sides, i.e. `shape_after_crop = shape - 2 * halo`. 994 To document a `halo` that is already cropped by the model `shape.offset` has to be used instead.""" 995 996 postprocessing: List[PostprocessingDescr] = Field(default_factory=list) 997 """Description of how this output should be postprocessed.""" 998 999 @model_validator(mode="after") 1000 def matching_halo_length(self) -> Self: 1001 if self.halo and len(self.halo) != len(self.shape): 1002 raise ValueError( 1003 f"halo {self.halo} has to have same length as shape {self.shape}!" 1004 ) 1005 1006 return self 1007 1008 @model_validator(mode="after") 1009 def validate_postprocessing_kwargs(self) -> Self: 1010 for p in self.postprocessing: 1011 kwargs_axes = p.kwargs.get("axes", "") 1012 if not isinstance(kwargs_axes, str): 1013 raise ValueError(f"Expected {kwargs_axes} to be a string") 1014 1015 if any(a not in self.axes for a in kwargs_axes): 1016 raise ValueError("`kwargs.axes` needs to be subset of axes") 1017 1018 return self
Data type. The data flow in bioimage.io models is explained [in this diagram](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit).
Description of how this output should be postprocessed.
1024class RunMode(Node): 1025 name: Annotated[ 1026 Union[KnownRunMode, str], warn(KnownRunMode, "Unknown run mode '{value}'.") 1027 ] 1028 """Run mode name""" 1029 1030 kwargs: Dict[str, Any] = Field(default_factory=dict) 1031 """Run mode specific key word arguments"""
1034class LinkedModel(Node): 1035 """Reference to a bioimage.io model.""" 1036 1037 id: Annotated[ModelId, Field(examples=["affable-shark", "ambitious-sloth"])] 1038 """A valid model `id` from the bioimage.io collection.""" 1039 1040 version_number: Optional[int] = None 1041 """version number (n-th published version, not the semantic version) of linked model"""
Reference to a bioimage.io model.
1044def package_weights( 1045 value: Node, # Union[v0_4.WeightsDescr, v0_5.WeightsDescr] 1046 handler: SerializerFunctionWrapHandler, 1047 info: SerializationInfo, 1048): 1049 ctxt = packaging_context_var.get() 1050 if ctxt is not None and ctxt.weights_priority_order is not None: 1051 for wf in ctxt.weights_priority_order: 1052 w = getattr(value, wf, None) 1053 if w is not None: 1054 break 1055 else: 1056 raise ValueError( 1057 "None of the weight formats in `weights_priority_order`" 1058 + f" ({ctxt.weights_priority_order}) is present in the given model." 1059 ) 1060 1061 assert isinstance(w, Node), type(w) 1062 # construct WeightsDescr with new single weight format entry 1063 new_w = w.model_construct(**{k: v for k, v in w if k != "parent"}) 1064 value = value.model_construct(None, **{wf: new_w}) 1065 1066 return handler( 1067 value, info # pyright: ignore[reportArgumentType] # taken from pydantic docs 1068 )
1071class ModelDescr(GenericModelDescrBase): 1072 """Specification of the fields used in a bioimage.io-compliant RDF that describes AI models with pretrained weights. 1073 1074 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 1075 """ 1076 1077 implemented_format_version: ClassVar[Literal["0.4.10"]] = "0.4.10" 1078 if TYPE_CHECKING: 1079 format_version: Literal["0.4.10"] = "0.4.10" 1080 else: 1081 format_version: Literal["0.4.10"] 1082 """Version of the bioimage.io model description specification used. 1083 When creating a new model always use the latest micro/patch version described here. 1084 The `format_version` is important for any consumer software to understand how to parse the fields. 1085 """ 1086 1087 implemented_type: ClassVar[Literal["model"]] = "model" 1088 if TYPE_CHECKING: 1089 type: Literal["model"] = "model" 1090 else: 1091 type: Literal["model"] 1092 """Specialized resource type 'model'""" 1093 1094 id: Optional[ModelId] = None 1095 """bioimage.io-wide unique resource identifier 1096 assigned by bioimage.io; version **un**specific.""" 1097 1098 authors: NotEmpty[ # pyright: ignore[reportGeneralTypeIssues] # make mandatory 1099 List[Author] 1100 ] 1101 """The authors are the creators of the model RDF and the primary points of contact.""" 1102 1103 documentation: Annotated[ 1104 ImportantFileSource, 1105 Field( 1106 examples=[ 1107 "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md", 1108 "README.md", 1109 ], 1110 ), 1111 ] 1112 """URL or relative path to a markdown file with additional documentation. 1113 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. 1114 The documentation should include a '[#[#]]# Validation' (sub)section 1115 with details on how to quantitatively validate the model on unseen data.""" 1116 1117 inputs: NotEmpty[List[InputTensorDescr]] 1118 """Describes the input tensors expected by this model.""" 1119 1120 license: Annotated[ 1121 Union[LicenseId, str], 1122 warn(LicenseId, "Unknown license id '{value}'."), 1123 Field(examples=["CC0-1.0", "MIT", "BSD-2-Clause"]), 1124 ] 1125 """A [SPDX license identifier](https://spdx.org/licenses/). 1126 We do notsupport custom license beyond the SPDX license list, if you need that please 1127 [open a GitHub issue](https://github.com/bioimage-io/spec-bioimage-io/issues/new/choose 1128 ) to discuss your intentions with the community.""" 1129 1130 name: Annotated[ 1131 str, 1132 MinLen(1), 1133 warn(MinLen(5), "Name shorter than 5 characters.", INFO), 1134 warn(MaxLen(64), "Name longer than 64 characters.", INFO), 1135 ] 1136 """A human-readable name of this model. 
1137 It should be no longer than 64 characters and only contain letter, number, underscore, minus or space characters.""" 1138 1139 outputs: NotEmpty[List[OutputTensorDescr]] 1140 """Describes the output tensors.""" 1141 1142 @field_validator("inputs", "outputs") 1143 @classmethod 1144 def unique_tensor_descr_names( 1145 cls, value: Sequence[Union[InputTensorDescr, OutputTensorDescr]] 1146 ) -> Sequence[Union[InputTensorDescr, OutputTensorDescr]]: 1147 unique_names = {str(v.name) for v in value} 1148 if len(unique_names) != len(value): 1149 raise ValueError("Duplicate tensor descriptor names") 1150 1151 return value 1152 1153 @model_validator(mode="after") 1154 def unique_io_names(self) -> Self: 1155 unique_names = {str(ss.name) for s in (self.inputs, self.outputs) for ss in s} 1156 if len(unique_names) != (len(self.inputs) + len(self.outputs)): 1157 raise ValueError("Duplicate tensor descriptor names across inputs/outputs") 1158 1159 return self 1160 1161 @model_validator(mode="after") 1162 def minimum_shape2valid_output(self) -> Self: 1163 tensors_by_name: Dict[ 1164 TensorName, Union[InputTensorDescr, OutputTensorDescr] 1165 ] = {t.name: t for t in self.inputs + self.outputs} 1166 1167 for out in self.outputs: 1168 if isinstance(out.shape, ImplicitOutputShape): 1169 ndim_ref = len(tensors_by_name[out.shape.reference_tensor].shape) 1170 ndim_out_ref = len( 1171 [scale for scale in out.shape.scale if scale is not None] 1172 ) 1173 if ndim_ref != ndim_out_ref: 1174 expanded_dim_note = ( 1175 " Note that expanded dimensions (`scale`: null) are not" 1176 + f" counted for {out.name}'sdimensionality here." 1177 if None in out.shape.scale 1178 else "" 1179 ) 1180 raise ValueError( 1181 f"Referenced tensor '{out.shape.reference_tensor}' with" 1182 + f" {ndim_ref} dimensions does not match output tensor" 1183 + f" '{out.name}' with" 1184 + f" {ndim_out_ref} dimensions.{expanded_dim_note}" 1185 ) 1186 1187 min_out_shape = self._get_min_shape(out, tensors_by_name) 1188 if out.halo: 1189 halo = out.halo 1190 halo_msg = f" for halo {out.halo}" 1191 else: 1192 halo = [0] * len(min_out_shape) 1193 halo_msg = "" 1194 1195 if any([s - 2 * h < 1 for s, h in zip(min_out_shape, halo)]): 1196 raise ValueError( 1197 f"Minimal shape {min_out_shape} of output {out.name} is too" 1198 + f" small{halo_msg}." 1199 ) 1200 1201 return self 1202 1203 @classmethod 1204 def _get_min_shape( 1205 cls, 1206 t: Union[InputTensorDescr, OutputTensorDescr], 1207 tensors_by_name: Dict[TensorName, Union[InputTensorDescr, OutputTensorDescr]], 1208 ) -> Sequence[int]: 1209 """output with subtracted halo has to result in meaningful output even for the minimal input 1210 see https://github.com/bioimage-io/spec-bioimage-io/issues/392 1211 """ 1212 if isinstance(t.shape, collections.abc.Sequence): 1213 return t.shape 1214 elif isinstance(t.shape, ParameterizedInputShape): 1215 return t.shape.min 1216 elif isinstance(t.shape, ImplicitOutputShape): 1217 pass 1218 else: 1219 assert_never(t.shape) 1220 1221 ref_shape = cls._get_min_shape( 1222 tensors_by_name[t.shape.reference_tensor], tensors_by_name 1223 ) 1224 1225 if None not in t.shape.scale: 1226 scale: Sequence[float, ...] 
= t.shape.scale # type: ignore 1227 else: 1228 expanded_dims = [idx for idx, sc in enumerate(t.shape.scale) if sc is None] 1229 new_ref_shape: List[int] = [] 1230 for idx in range(len(t.shape.scale)): 1231 ref_idx = idx - sum(int(exp < idx) for exp in expanded_dims) 1232 new_ref_shape.append(1 if idx in expanded_dims else ref_shape[ref_idx]) 1233 1234 ref_shape = new_ref_shape 1235 assert len(ref_shape) == len(t.shape.scale) 1236 scale = [0.0 if sc is None else sc for sc in t.shape.scale] 1237 1238 offset = t.shape.offset 1239 assert len(offset) == len(scale) 1240 return [int(rs * s + 2 * off) for rs, s, off in zip(ref_shape, scale, offset)] 1241 1242 @model_validator(mode="after") 1243 def validate_tensor_references_in_inputs(self) -> Self: 1244 for t in self.inputs: 1245 for proc in t.preprocessing: 1246 if "reference_tensor" not in proc.kwargs: 1247 continue 1248 1249 ref_tensor = proc.kwargs["reference_tensor"] 1250 if ref_tensor is not None and str(ref_tensor) not in { 1251 str(t.name) for t in self.inputs 1252 }: 1253 raise ValueError(f"'{ref_tensor}' not found in inputs") 1254 1255 if ref_tensor == t.name: 1256 raise ValueError( 1257 f"invalid self reference for preprocessing of tensor {t.name}" 1258 ) 1259 1260 return self 1261 1262 @model_validator(mode="after") 1263 def validate_tensor_references_in_outputs(self) -> Self: 1264 for t in self.outputs: 1265 for proc in t.postprocessing: 1266 if "reference_tensor" not in proc.kwargs: 1267 continue 1268 ref_tensor = proc.kwargs["reference_tensor"] 1269 if ref_tensor is not None and str(ref_tensor) not in { 1270 str(t.name) for t in self.inputs 1271 }: 1272 raise ValueError(f"{ref_tensor} not found in inputs") 1273 1274 return self 1275 1276 packaged_by: List[Author] = Field(default_factory=list) 1277 """The persons that have packaged and uploaded this model. 1278 Only required if those persons differ from the `authors`.""" 1279 1280 parent: Optional[LinkedModel] = None 1281 """The model from which this model is derived, e.g. by fine-tuning the weights.""" 1282 1283 @field_validator("parent", mode="before") 1284 @classmethod 1285 def ignore_url_parent(cls, parent: Any): 1286 if isinstance(parent, dict): 1287 return None 1288 1289 else: 1290 return parent 1291 1292 run_mode: Optional[RunMode] = None 1293 """Custom run mode for this model: for more complex prediction procedures like test time 1294 data augmentation that currently cannot be expressed in the specification. 1295 No standard run modes are defined yet.""" 1296 1297 sample_inputs: List[ImportantFileSource] = Field(default_factory=list) 1298 """URLs/relative paths to sample inputs to illustrate possible inputs for the model, 1299 for example stored as PNG or TIFF images. 1300 The sample files primarily serve to inform a human user about an example use case""" 1301 1302 sample_outputs: List[ImportantFileSource] = Field(default_factory=list) 1303 """URLs/relative paths to sample outputs corresponding to the `sample_inputs`.""" 1304 1305 test_inputs: NotEmpty[ 1306 List[Annotated[ImportantFileSource, WithSuffix(".npy", case_sensitive=True)]] 1307 ] 1308 """Test input tensors compatible with the `inputs` description for a **single test case**. 1309 This means if your model has more than one input, you should provide one URL/relative path for each input. 1310 Each test input should be a file with an ndarray in 1311 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format). 
1312 The extension must be '.npy'.""" 1313 1314 test_outputs: NotEmpty[ 1315 List[Annotated[ImportantFileSource, WithSuffix(".npy", case_sensitive=True)]] 1316 ] 1317 """Analog to `test_inputs`.""" 1318 1319 timestamp: Datetime 1320 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format 1321 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).""" 1322 1323 training_data: Union[LinkedDataset, DatasetDescr, None] = None 1324 """The dataset used to train this model""" 1325 1326 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)] 1327 """The weights for this model. 1328 Weights can be given for different formats, but should otherwise be equivalent. 1329 The available weight formats determine which consumers can use this model.""" 1330 1331 @model_validator(mode="before") 1332 @classmethod 1333 def _convert_from_older_format( 1334 cls, data: BioimageioYamlContent, / 1335 ) -> BioimageioYamlContent: 1336 convert_from_older_format(data) 1337 return data 1338 1339 def get_input_test_arrays(self) -> List[NDArray[Any]]: 1340 data = [load_array(download(ipt).path) for ipt in self.test_inputs] 1341 assert all(isinstance(d, np.ndarray) for d in data) 1342 return data 1343 1344 def get_output_test_arrays(self) -> List[NDArray[Any]]: 1345 data = [load_array(download(out).path) for out in self.test_outputs] 1346 assert all(isinstance(d, np.ndarray) for d in data) 1347 return data
Specification of the fields used in a bioimage.io-compliant RDF that describes AI models with pretrained weights.
These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
bioimage.io-wide unique resource identifier assigned by bioimage.io; version unspecific.
URL or relative path to a markdown file with additional documentation.
The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
The documentation should include a '[#[#]]# Validation' (sub)section with details on how to quantitatively validate the model on unseen data.
Describes the input tensors expected by this model.
A [SPDX license identifier](https://spdx.org/licenses/). We do not support custom licenses beyond the SPDX license list; if you need that please [open a GitHub issue](https://github.com/bioimage-io/spec-bioimage-io/issues/new/choose) to discuss your intentions with the community.
A human-readable name of this model. It should be no longer than 64 characters and only contain letter, number, underscore, minus or space characters.
The persons that have packaged and uploaded this model.
Only required if those persons differ from the `authors`.
The model from which this model is derived, e.g. by fine-tuning the weights.
Custom run mode for this model: for more complex prediction procedures like test time data augmentation that currently cannot be expressed in the specification. No standard run modes are defined yet.
URLs/relative paths to sample inputs to illustrate possible inputs for the model, for example stored as PNG or TIFF images. The sample files primarily serve to inform a human user about an example use case
URLs/relative paths to sample outputs corresponding to the `sample_inputs`.
Test input tensors compatible with the `inputs` description for a **single test case**.
This means if your model has more than one input, you should provide one URL/relative path for each input.
Each test input should be a file with an ndarray in [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
The extension must be '.npy'.
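As a minimal sketch (file name is made up), such test inputs can be loaded directly with numpy:

```python
import numpy as np

test_inputs = ["test_input_raw.npy"]  # assumed: one '.npy' file per input tensor
arrays = [np.load(p) for p in test_inputs]
for a in arrays:
    print(a.shape, a.dtype)
```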
Analogous to `test_inputs`.
The dataset used to train this model
The weights for this model. Weights can be given for different formats, but should otherwise be equivalent. The available weight formats determine which consumers can use this model.
124 def wrapped_model_post_init(self: BaseModel, context: Any, /) -> None: 125 """We need to both initialize private attributes and call the user-defined model_post_init 126 method. 127 """ 128 init_private_attributes(self, context) 129 original_model_post_init(self, context)
We need to both initialize private attributes and call the user-defined model_post_init method.