bioimageio.spec.model.v0_4
1from __future__ import annotations 2 3import collections.abc 4from typing import ( 5 Any, 6 ClassVar, 7 Dict, 8 FrozenSet, 9 List, 10 Literal, 11 Optional, 12 Sequence, 13 Tuple, 14 Union, 15) 16 17import numpy as np 18from annotated_types import Ge, Interval, MaxLen, MinLen, MultipleOf 19from numpy.typing import NDArray 20from pydantic import ( 21 AllowInfNan, 22 Discriminator, 23 Field, 24 SerializationInfo, 25 SerializerFunctionWrapHandler, 26 TypeAdapter, 27 ValidationInfo, 28 WrapSerializer, 29 field_validator, 30 model_validator, 31) 32from typing_extensions import Annotated, LiteralString, Self, assert_never 33 34from .._internal.common_nodes import ( 35 KwargsNode, 36 Node, 37 NodeWithExplicitlySetFields, 38 StringNode, 39) 40from .._internal.constants import SHA256_HINT 41from .._internal.field_validation import validate_unique_entries 42from .._internal.field_warning import issue_warning, warn 43from .._internal.io import ( 44 BioimageioYamlContent, 45 WithSuffix, 46 download, 47 include_in_package_serializer, 48) 49from .._internal.io import FileDescr as FileDescr 50from .._internal.io_basics import AbsoluteFilePath as AbsoluteFilePath 51from .._internal.io_basics import Sha256 as Sha256 52from .._internal.io_utils import load_array 53from .._internal.packaging_context import packaging_context_var 54from .._internal.types import Datetime as Datetime 55from .._internal.types import Identifier as Identifier 56from .._internal.types import ImportantFileSource, LowerCaseIdentifier 57from .._internal.types import LicenseId as LicenseId 58from .._internal.types import NotEmpty as NotEmpty 59from .._internal.url import HttpUrl as HttpUrl 60from .._internal.validator_annotations import AfterValidator, RestrictCharacters 61from .._internal.version_type import Version as Version 62from .._internal.warning_levels import ALERT, INFO 63from ..dataset.v0_2 import VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS 64from ..dataset.v0_2 import DatasetDescr as DatasetDescr 65from ..dataset.v0_2 import LinkedDataset as LinkedDataset 66from ..generic.v0_2 import AttachmentsDescr as AttachmentsDescr 67from ..generic.v0_2 import Author as Author 68from ..generic.v0_2 import BadgeDescr as BadgeDescr 69from ..generic.v0_2 import CiteEntry as CiteEntry 70from ..generic.v0_2 import Doi as Doi 71from ..generic.v0_2 import GenericModelDescrBase 72from ..generic.v0_2 import LinkedResource as LinkedResource 73from ..generic.v0_2 import Maintainer as Maintainer 74from ..generic.v0_2 import OrcidId as OrcidId 75from ..generic.v0_2 import RelativeFilePath as RelativeFilePath 76from ..generic.v0_2 import ResourceId as ResourceId 77from ..generic.v0_2 import Uploader as Uploader 78from ._v0_4_converter import convert_from_older_format 79 80 81class ModelId(ResourceId): 82 pass 83 84 85AxesStr = Annotated[ 86 str, RestrictCharacters("bitczyx"), AfterValidator(validate_unique_entries) 87] 88AxesInCZYX = Annotated[ 89 str, RestrictCharacters("czyx"), AfterValidator(validate_unique_entries) 90] 91 92PostprocessingName = Literal[ 93 "binarize", 94 "clip", 95 "scale_linear", 96 "sigmoid", 97 "zero_mean_unit_variance", 98 "scale_range", 99 "scale_mean_variance", 100] 101PreprocessingName = Literal[ 102 "binarize", 103 "clip", 104 "scale_linear", 105 "sigmoid", 106 "zero_mean_unit_variance", 107 "scale_range", 108] 109 110 111class TensorName(LowerCaseIdentifier): 112 pass 113 114 115class CallableFromDepencency(StringNode): 116 _pattern = r"^.+\..+$" 117 _submodule_adapter = TypeAdapter(Identifier) 118 119 
module_name: str 120 121 @field_validator("module_name", mode="after") 122 def check_submodules(cls, module_name: str) -> str: 123 for submod in module_name.split("."): 124 _ = cls._submodule_adapter.validate_python(submod) 125 126 return module_name 127 128 callable_name: Identifier 129 130 @classmethod 131 def _get_data(cls, valid_string_data: str): 132 *mods, callname = valid_string_data.split(".") 133 return dict(module_name=".".join(mods), callable_name=callname) 134 135 136class CallableFromFile(StringNode): 137 _pattern = r"^.+:.+$" 138 source_file: Annotated[ 139 Union[RelativeFilePath, HttpUrl], 140 Field(union_mode="left_to_right"), 141 include_in_package_serializer, 142 ] 143 """∈📦 Python module that implements `callable_name`""" 144 callable_name: Identifier 145 """The Python identifier of """ 146 147 @classmethod 148 def _get_data(cls, valid_string_data: str): 149 *file_parts, callname = valid_string_data.split(":") 150 return dict(source_file=":".join(file_parts), callable_name=callname) 151 152 153CustomCallable = Annotated[ 154 Union[CallableFromFile, CallableFromDepencency], Field(union_mode="left_to_right") 155] 156 157 158class Dependencies(StringNode): 159 _pattern = r"^.+:.+$" 160 manager: Annotated[NotEmpty[str], Field(examples=["conda", "maven", "pip"])] 161 """Dependency manager""" 162 163 file: Annotated[ 164 ImportantFileSource, 165 Field(examples=["environment.yaml", "pom.xml", "requirements.txt"]), 166 ] 167 """∈📦 Dependency file""" 168 169 @classmethod 170 def _get_data(cls, valid_string_data: str): 171 manager, *file_parts = valid_string_data.split(":") 172 return dict(manager=manager, file=":".join(file_parts)) 173 174 175WeightsFormat = Literal[ 176 "keras_hdf5", 177 "onnx", 178 "pytorch_state_dict", 179 "tensorflow_js", 180 "tensorflow_saved_model_bundle", 181 "torchscript", 182] 183 184 185class WeightsDescr(Node): 186 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None 187 onnx: Optional[OnnxWeightsDescr] = None 188 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None 189 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None 190 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = ( 191 None 192 ) 193 torchscript: Optional[TorchscriptWeightsDescr] = None 194 195 @model_validator(mode="after") 196 def check_one_entry(self) -> Self: 197 if all( 198 entry is None 199 for entry in [ 200 self.keras_hdf5, 201 self.onnx, 202 self.pytorch_state_dict, 203 self.tensorflow_js, 204 self.tensorflow_saved_model_bundle, 205 self.torchscript, 206 ] 207 ): 208 raise ValueError("Missing weights entry") 209 210 return self 211 212 213class WeightsEntryDescrBase(FileDescr): 214 type: ClassVar[WeightsFormat] 215 weights_format_name: ClassVar[str] # human readable 216 217 source: ImportantFileSource 218 """∈📦 The weights file.""" 219 220 attachments: Annotated[ 221 Union[AttachmentsDescr, None], 222 warn(None, "Weights entry depends on additional attachments.", ALERT), 223 ] = None 224 """Attachments that are specific to this weights entry.""" 225 226 authors: Union[List[Author], None] = None 227 """Authors 228 Either the person(s) that have trained this model resulting in the original weights file. 229 (If this is the initial weights entry, i.e. it does not have a `parent`) 230 Or the person(s) who have converted the weights to this weights format. 231 (If this is a child weight, i.e. 
it has a `parent` field) 232 """ 233 234 dependencies: Annotated[ 235 Optional[Dependencies], 236 warn( 237 None, 238 "Custom dependencies ({value}) specified. Avoid this whenever possible " 239 + "to allow execution in a wider range of software environments.", 240 ), 241 Field( 242 examples=[ 243 "conda:environment.yaml", 244 "maven:./pom.xml", 245 "pip:./requirements.txt", 246 ] 247 ), 248 ] = None 249 """Dependency manager and dependency file, specified as `<dependency manager>:<relative file path>`.""" 250 251 parent: Annotated[ 252 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"]) 253 ] = None 254 """The source weights these weights were converted from. 255 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`, 256 The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights. 257 All weight entries except one (the initial set of weights resulting from training the model), 258 need to have this field.""" 259 260 @model_validator(mode="after") 261 def check_parent_is_not_self(self) -> Self: 262 if self.type == self.parent: 263 raise ValueError("Weights entry can't be it's own parent.") 264 265 return self 266 267 268class KerasHdf5WeightsDescr(WeightsEntryDescrBase): 269 type = "keras_hdf5" 270 weights_format_name: ClassVar[str] = "Keras HDF5" 271 tensorflow_version: Optional[Version] = None 272 """TensorFlow version used to create these weights""" 273 274 @field_validator("tensorflow_version", mode="after") 275 @classmethod 276 def _tfv(cls, value: Any): 277 if value is None: 278 issue_warning( 279 "missing. Please specify the TensorFlow version" 280 + " these weights were created with.", 281 value=value, 282 severity=ALERT, 283 field="tensorflow_version", 284 ) 285 return value 286 287 288class OnnxWeightsDescr(WeightsEntryDescrBase): 289 type = "onnx" 290 weights_format_name: ClassVar[str] = "ONNX" 291 opset_version: Optional[Annotated[int, Ge(7)]] = None 292 """ONNX opset version""" 293 294 @field_validator("opset_version", mode="after") 295 @classmethod 296 def _ov(cls, value: Any): 297 if value is None: 298 issue_warning( 299 "Missing ONNX opset version (aka ONNX opset number). " 300 + "Please specify the ONNX opset version these weights were created" 301 + " with.", 302 value=value, 303 severity=ALERT, 304 field="opset_version", 305 ) 306 return value 307 308 309class PytorchStateDictWeightsDescr(WeightsEntryDescrBase): 310 type = "pytorch_state_dict" 311 weights_format_name: ClassVar[str] = "Pytorch State Dict" 312 architecture: CustomCallable = Field( 313 examples=["my_function.py:MyNetworkClass", "my_module.submodule.get_my_model"] 314 ) 315 """callable returning a torch.nn.Module instance. 316 Local implementation: `<relative path to file>:<identifier of implementation within the file>`. 
317 Implementation in a dependency: `<dependency-package>.<[dependency-module]>.<identifier>`.""" 318 319 architecture_sha256: Annotated[ 320 Optional[Sha256], 321 Field( 322 description=( 323 "The SHA256 of the architecture source file, if the architecture is not" 324 " defined in a module listed in `dependencies`\n" 325 ) 326 + SHA256_HINT, 327 ), 328 ] = None 329 """The SHA256 of the architecture source file, 330 if the architecture is not defined in a module listed in `dependencies`""" 331 332 @model_validator(mode="after") 333 def check_architecture_sha256(self) -> Self: 334 if isinstance(self.architecture, CallableFromFile): 335 if self.architecture_sha256 is None: 336 raise ValueError( 337 "Missing required `architecture_sha256` for `architecture` with" 338 + " source file." 339 ) 340 elif self.architecture_sha256 is not None: 341 raise ValueError( 342 "Got `architecture_sha256` for architecture that does not have a source" 343 + " file." 344 ) 345 346 return self 347 348 kwargs: Dict[str, Any] = Field(default_factory=dict) 349 """key word arguments for the `architecture` callable""" 350 351 pytorch_version: Optional[Version] = None 352 """Version of the PyTorch library used. 353 If `depencencies` is specified it should include pytorch and the verison has to match. 354 (`dependencies` overrules `pytorch_version`)""" 355 356 @field_validator("pytorch_version", mode="after") 357 @classmethod 358 def _ptv(cls, value: Any): 359 if value is None: 360 issue_warning( 361 "missing. Please specify the PyTorch version these" 362 + " PyTorch state dict weights were created with.", 363 value=value, 364 severity=ALERT, 365 field="pytorch_version", 366 ) 367 return value 368 369 370class TorchscriptWeightsDescr(WeightsEntryDescrBase): 371 type = "torchscript" 372 weights_format_name: ClassVar[str] = "TorchScript" 373 pytorch_version: Optional[Version] = None 374 """Version of the PyTorch library used.""" 375 376 @field_validator("pytorch_version", mode="after") 377 @classmethod 378 def _ptv(cls, value: Any): 379 if value is None: 380 issue_warning( 381 "missing. Please specify the PyTorch version these" 382 + " Torchscript weights were created with.", 383 value=value, 384 severity=ALERT, 385 field="pytorch_version", 386 ) 387 return value 388 389 390class TensorflowJsWeightsDescr(WeightsEntryDescrBase): 391 type = "tensorflow_js" 392 weights_format_name: ClassVar[str] = "Tensorflow.js" 393 tensorflow_version: Optional[Version] = None 394 """Version of the TensorFlow library used.""" 395 396 @field_validator("tensorflow_version", mode="after") 397 @classmethod 398 def _tfv(cls, value: Any): 399 if value is None: 400 issue_warning( 401 "missing. Please specify the TensorFlow version" 402 + " these TensorflowJs weights were created with.", 403 value=value, 404 severity=ALERT, 405 field="tensorflow_version", 406 ) 407 return value 408 409 source: ImportantFileSource 410 """∈📦 The multi-file weights. 411 All required files/folders should be a zip archive.""" 412 413 414class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase): 415 type = "tensorflow_saved_model_bundle" 416 weights_format_name: ClassVar[str] = "Tensorflow Saved Model" 417 tensorflow_version: Optional[Version] = None 418 """Version of the TensorFlow library used.""" 419 420 @field_validator("tensorflow_version", mode="after") 421 @classmethod 422 def _tfv(cls, value: Any): 423 if value is None: 424 issue_warning( 425 "missing. 
Please specify the TensorFlow version" 426 + " these Tensorflow saved model bundle weights were created with.", 427 value=value, 428 severity=ALERT, 429 field="tensorflow_version", 430 ) 431 return value 432 433 434class ParameterizedInputShape(Node): 435 """A sequence of valid shapes given by `shape_k = min + k * step for k in {0, 1, ...}`.""" 436 437 min: NotEmpty[List[int]] 438 """The minimum input shape""" 439 440 step: NotEmpty[List[int]] 441 """The minimum shape change""" 442 443 def __len__(self) -> int: 444 return len(self.min) 445 446 @model_validator(mode="after") 447 def matching_lengths(self) -> Self: 448 if len(self.min) != len(self.step): 449 raise ValueError("`min` and `step` required to have the same length") 450 451 return self 452 453 454class ImplicitOutputShape(Node): 455 """Output tensor shape depending on an input tensor shape. 456 `shape(output_tensor) = shape(input_tensor) * scale + 2 * offset`""" 457 458 reference_tensor: TensorName 459 """Name of the reference tensor.""" 460 461 scale: NotEmpty[List[Optional[float]]] 462 """output_pix/input_pix for each dimension. 463 'null' values indicate new dimensions, whose length is defined by 2*`offset`""" 464 465 offset: NotEmpty[List[Union[int, Annotated[float, MultipleOf(0.5)]]]] 466 """Position of origin wrt to input.""" 467 468 def __len__(self) -> int: 469 return len(self.scale) 470 471 @model_validator(mode="after") 472 def matching_lengths(self) -> Self: 473 if len(self.scale) != len(self.offset): 474 raise ValueError( 475 f"scale {self.scale} has to have same length as offset {self.offset}!" 476 ) 477 # if we have an expanded dimension, make sure that it's offet is not zero 478 for sc, off in zip(self.scale, self.offset): 479 if sc is None and not off: 480 raise ValueError("`offset` must not be zero if `scale` is none/zero") 481 482 return self 483 484 485class TensorDescrBase(Node): 486 name: TensorName 487 """Tensor name. No duplicates are allowed.""" 488 489 description: str = "" 490 491 axes: AxesStr 492 """Axes identifying characters. Same length and order as the axes in `shape`. 493 | axis | description | 494 | --- | --- | 495 | b | batch (groups multiple samples) | 496 | i | instance/index/element | 497 | t | time | 498 | c | channel | 499 | z | spatial dimension z | 500 | y | spatial dimension y | 501 | x | spatial dimension x | 502 """ 503 504 data_range: Optional[ 505 Tuple[Annotated[float, AllowInfNan(True)], Annotated[float, AllowInfNan(True)]] 506 ] = None 507 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor. 508 If not specified, the full data range that can be expressed in `data_type` is allowed.""" 509 510 511class ProcessingKwargs(KwargsNode): 512 """base class for pre-/postprocessing key word arguments""" 513 514 515class ProcessingDescrBase(NodeWithExplicitlySetFields): 516 """processing base class""" 517 518 # name: Literal[PreprocessingName, PostprocessingName] # todo: make abstract field 519 fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({"name"}) 520 521 522class BinarizeKwargs(ProcessingKwargs): 523 """key word arguments for `BinarizeDescr`""" 524 525 threshold: float 526 """The fixed threshold""" 527 528 529class BinarizeDescr(ProcessingDescrBase): 530 """BinarizeDescr the tensor with a fixed `BinarizeKwargs.threshold`. 531 Values above the threshold will be set to one, values below the threshold to zero. 
532 """ 533 534 name: Literal["binarize"] = "binarize" 535 kwargs: BinarizeKwargs 536 537 538class ClipKwargs(ProcessingKwargs): 539 """key word arguments for `ClipDescr`""" 540 541 min: float 542 """minimum value for clipping""" 543 max: float 544 """maximum value for clipping""" 545 546 547class ClipDescr(ProcessingDescrBase): 548 """Clip tensor values to a range. 549 550 Set tensor values below `ClipKwargs.min` to `ClipKwargs.min` 551 and above `ClipKwargs.max` to `ClipKwargs.max`. 552 """ 553 554 name: Literal["clip"] = "clip" 555 556 kwargs: ClipKwargs 557 558 559class ScaleLinearKwargs(ProcessingKwargs): 560 """key word arguments for `ScaleLinearDescr`""" 561 562 axes: Annotated[Optional[AxesInCZYX], Field(examples=["xy"])] = None 563 """The subset of axes to scale jointly. 564 For example xy to scale the two image axes for 2d data jointly.""" 565 566 gain: Union[float, List[float]] = 1.0 567 """multiplicative factor""" 568 569 offset: Union[float, List[float]] = 0.0 570 """additive term""" 571 572 @model_validator(mode="after") 573 def either_gain_or_offset(self) -> Self: 574 if ( 575 self.gain == 1.0 576 or isinstance(self.gain, list) 577 and all(g == 1.0 for g in self.gain) 578 ) and ( 579 self.offset == 0.0 580 or isinstance(self.offset, list) 581 and all(off == 0.0 for off in self.offset) 582 ): 583 raise ValueError( 584 "Redunt linear scaling not allowd. Set `gain` != 1.0 and/or `offset` !=" 585 + " 0.0." 586 ) 587 588 return self 589 590 591class ScaleLinearDescr(ProcessingDescrBase): 592 """Fixed linear scaling.""" 593 594 name: Literal["scale_linear"] = "scale_linear" 595 kwargs: ScaleLinearKwargs 596 597 598class SigmoidDescr(ProcessingDescrBase): 599 """The logistic sigmoid funciton, a.k.a. expit function.""" 600 601 name: Literal["sigmoid"] = "sigmoid" 602 603 @property 604 def kwargs(self) -> ProcessingKwargs: 605 """empty kwargs""" 606 return ProcessingKwargs() 607 608 609class ZeroMeanUnitVarianceKwargs(ProcessingKwargs): 610 """key word arguments for `ZeroMeanUnitVarianceDescr`""" 611 612 mode: Literal["fixed", "per_dataset", "per_sample"] = "fixed" 613 """Mode for computing mean and variance. 614 | mode | description | 615 | ----------- | ------------------------------------ | 616 | fixed | Fixed values for mean and variance | 617 | per_dataset | Compute for the entire dataset | 618 | per_sample | Compute for each sample individually | 619 """ 620 axes: Annotated[AxesInCZYX, Field(examples=["xy"])] 621 """The subset of axes to normalize jointly. 622 For example `xy` to normalize the two image axes for 2d data jointly.""" 623 624 mean: Annotated[ 625 Union[float, NotEmpty[List[float]], None], Field(examples=[(1.1, 2.2, 3.3)]) 626 ] = None 627 """The mean value(s) to use for `mode: fixed`. 628 For example `[1.1, 2.2, 3.3]` in the case of a 3 channel image with `axes: xy`.""" 629 # todo: check if means match input axes (for mode 'fixed') 630 631 std: Annotated[ 632 Union[float, NotEmpty[List[float]], None], Field(examples=[(0.1, 0.2, 0.3)]) 633 ] = None 634 """The standard deviation values to use for `mode: fixed`. 
Analogous to mean.""" 635 636 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 637 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`.""" 638 639 @model_validator(mode="after") 640 def mean_and_std_match_mode(self) -> Self: 641 if self.mode == "fixed" and (self.mean is None or self.std is None): 642 raise ValueError("`mean` and `std` are required for `mode: fixed`.") 643 elif self.mode != "fixed" and (self.mean is not None or self.std is not None): 644 raise ValueError(f"`mean` and `std` not allowed for `mode: {self.mode}`") 645 646 return self 647 648 649class ZeroMeanUnitVarianceDescr(ProcessingDescrBase): 650 """Subtract mean and divide by variance.""" 651 652 name: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance" 653 kwargs: ZeroMeanUnitVarianceKwargs 654 655 656class ScaleRangeKwargs(ProcessingKwargs): 657 """key word arguments for `ScaleRangeDescr` 658 659 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default) 660 this processing step normalizes data to the [0, 1] intervall. 661 For other percentiles the normalized values will partially be outside the [0, 1] 662 intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the 663 normalized values to a range. 664 """ 665 666 mode: Literal["per_dataset", "per_sample"] 667 """Mode for computing percentiles. 668 | mode | description | 669 | ----------- | ------------------------------------ | 670 | per_dataset | compute for the entire dataset | 671 | per_sample | compute for each sample individually | 672 """ 673 axes: Annotated[AxesInCZYX, Field(examples=["xy"])] 674 """The subset of axes to normalize jointly. 675 For example xy to normalize the two image axes for 2d data jointly.""" 676 677 min_percentile: Annotated[Union[int, float], Interval(ge=0, lt=100)] = 0.0 678 """The lower percentile used to determine the value to align with zero.""" 679 680 max_percentile: Annotated[Union[int, float], Interval(gt=1, le=100)] = 100.0 681 """The upper percentile used to determine the value to align with one. 682 Has to be bigger than `min_percentile`. 683 The range is 1 to 100 instead of 0 to 100 to avoid mistakenly 684 accepting percentiles specified in the range 0.0 to 1.0.""" 685 686 @model_validator(mode="after") 687 def min_smaller_max(self, info: ValidationInfo) -> Self: 688 if self.min_percentile >= self.max_percentile: 689 raise ValueError( 690 f"min_percentile {self.min_percentile} >= max_percentile" 691 + f" {self.max_percentile}" 692 ) 693 694 return self 695 696 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 697 """Epsilon for numeric stability. 698 `out = (tensor - v_lower) / (v_upper - v_lower + eps)`; 699 with `v_lower,v_upper` values at the respective percentiles.""" 700 701 reference_tensor: Optional[TensorName] = None 702 """Tensor name to compute the percentiles from. Default: The tensor itself. 703 For any tensor in `inputs` only input tensor references are allowed. 704 For a tensor in `outputs` only input tensor refereences are allowed if `mode: per_dataset`""" 705 706 707class ScaleRangeDescr(ProcessingDescrBase): 708 """Scale with percentiles.""" 709 710 name: Literal["scale_range"] = "scale_range" 711 kwargs: ScaleRangeKwargs 712 713 714class ScaleMeanVarianceKwargs(ProcessingKwargs): 715 """key word arguments for `ScaleMeanVarianceDescr`""" 716 717 mode: Literal["per_dataset", "per_sample"] 718 """Mode for computing mean and variance. 
719 | mode | description | 720 | ----------- | ------------------------------------ | 721 | per_dataset | Compute for the entire dataset | 722 | per_sample | Compute for each sample individually | 723 """ 724 725 reference_tensor: TensorName 726 """Name of tensor to match.""" 727 728 axes: Annotated[Optional[AxesInCZYX], Field(examples=["xy"])] = None 729 """The subset of axes to scale jointly. 730 For example xy to normalize the two image axes for 2d data jointly. 731 Default: scale all non-batch axes jointly.""" 732 733 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 734 """Epsilon for numeric stability: 735 "`out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.""" 736 737 738class ScaleMeanVarianceDescr(ProcessingDescrBase): 739 """Scale the tensor s.t. its mean and variance match a reference tensor.""" 740 741 name: Literal["scale_mean_variance"] = "scale_mean_variance" 742 kwargs: ScaleMeanVarianceKwargs 743 744 745PreprocessingDescr = Annotated[ 746 Union[ 747 BinarizeDescr, 748 ClipDescr, 749 ScaleLinearDescr, 750 SigmoidDescr, 751 ZeroMeanUnitVarianceDescr, 752 ScaleRangeDescr, 753 ], 754 Discriminator("name"), 755] 756PostprocessingDescr = Annotated[ 757 Union[ 758 BinarizeDescr, 759 ClipDescr, 760 ScaleLinearDescr, 761 SigmoidDescr, 762 ZeroMeanUnitVarianceDescr, 763 ScaleRangeDescr, 764 ScaleMeanVarianceDescr, 765 ], 766 Discriminator("name"), 767] 768 769 770class InputTensorDescr(TensorDescrBase): 771 data_type: Literal["float32", "uint8", "uint16"] 772 """For now an input tensor is expected to be given as `float32`. 773 The data flow in bioimage.io models is explained 774 [in this diagram.](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit).""" 775 776 shape: Annotated[ 777 Union[Sequence[int], ParameterizedInputShape], 778 Field( 779 examples=[(1, 512, 512, 1), dict(min=(1, 64, 64, 1), step=(0, 32, 32, 0))] 780 ), 781 ] 782 """Specification of input tensor shape.""" 783 784 preprocessing: List[PreprocessingDescr] = Field(default_factory=list) 785 """Description of how this input should be preprocessed.""" 786 787 @model_validator(mode="after") 788 def zero_batch_step_and_one_batch_size(self) -> Self: 789 bidx = self.axes.find("b") 790 if bidx == -1: 791 return self 792 793 if isinstance(self.shape, ParameterizedInputShape): 794 step = self.shape.step 795 shape = self.shape.min 796 if step[bidx] != 0: 797 raise ValueError( 798 "Input shape step has to be zero in the batch dimension (the batch" 799 + " dimension can always be increased, but `step` should specify how" 800 + " to increase the minimal shape to find the largest single batch" 801 + " shape)" 802 ) 803 else: 804 shape = self.shape 805 806 if shape[bidx] != 1: 807 raise ValueError("Input shape has to be 1 in the batch dimension b.") 808 809 return self 810 811 @model_validator(mode="after") 812 def validate_preprocessing_kwargs(self) -> Self: 813 for p in self.preprocessing: 814 kwargs_axes = p.kwargs.get("axes", "") 815 if not isinstance(kwargs_axes, str): 816 raise ValueError( 817 f"Expected an `axes` string, but got {type(kwargs_axes)}" 818 ) 819 820 if any(a not in self.axes for a in kwargs_axes): 821 raise ValueError("`kwargs.axes` needs to be subset of `axes`") 822 823 return self 824 825 826class OutputTensorDescr(TensorDescrBase): 827 data_type: Literal[ 828 "float32", 829 "float64", 830 "uint8", 831 "int8", 832 "uint16", 833 "int16", 834 "uint32", 835 "int32", 836 "uint64", 837 "int64", 838 "bool", 839 ] 840 """Data type. 
841 The data flow in bioimage.io models is explained 842 [in this diagram.](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit).""" 843 844 shape: Union[Sequence[int], ImplicitOutputShape] 845 """Output tensor shape.""" 846 847 halo: Optional[Sequence[int]] = None 848 """The `halo` that should be cropped from the output tensor to avoid boundary effects. 849 The `halo` is to be cropped from both sides, i.e. `shape_after_crop = shape - 2 * halo`. 850 To document a `halo` that is already cropped by the model `shape.offset` has to be used instead.""" 851 852 postprocessing: List[PostprocessingDescr] = Field(default_factory=list) 853 """Description of how this output should be postprocessed.""" 854 855 @model_validator(mode="after") 856 def matching_halo_length(self) -> Self: 857 if self.halo and len(self.halo) != len(self.shape): 858 raise ValueError( 859 f"halo {self.halo} has to have same length as shape {self.shape}!" 860 ) 861 862 return self 863 864 @model_validator(mode="after") 865 def validate_postprocessing_kwargs(self) -> Self: 866 for p in self.postprocessing: 867 kwargs_axes = p.kwargs.get("axes", "") 868 if not isinstance(kwargs_axes, str): 869 raise ValueError(f"Expected {kwargs_axes} to be a string") 870 871 if any(a not in self.axes for a in kwargs_axes): 872 raise ValueError("`kwargs.axes` needs to be subset of axes") 873 874 return self 875 876 877KnownRunMode = Literal["deepimagej"] 878 879 880class RunMode(Node): 881 name: Annotated[ 882 Union[KnownRunMode, str], warn(KnownRunMode, "Unknown run mode '{value}'.") 883 ] 884 """Run mode name""" 885 886 kwargs: Dict[str, Any] = Field(default_factory=dict) 887 """Run mode specific key word arguments""" 888 889 890class LinkedModel(Node): 891 """Reference to a bioimage.io model.""" 892 893 id: ModelId 894 """A valid model `id` from the bioimage.io collection.""" 895 896 version_number: Optional[int] = None 897 """version number (n-th published version, not the semantic version) of linked model""" 898 899 900def package_weights( 901 value: Node, # Union[v0_4.WeightsDescr, v0_5.WeightsDescr] 902 handler: SerializerFunctionWrapHandler, 903 info: SerializationInfo, 904): 905 ctxt = packaging_context_var.get() 906 if ctxt is not None and ctxt.weights_priority_order is not None: 907 for wf in ctxt.weights_priority_order: 908 w = getattr(value, wf, None) 909 if w is not None: 910 break 911 else: 912 raise ValueError( 913 "None of the weight formats in `weights_priority_order`" 914 + f" ({ctxt.weights_priority_order}) is present in the given model." 915 ) 916 917 assert isinstance(w, Node), type(w) 918 # construct WeightsDescr with new single weight format entry 919 new_w = w.model_construct(**{k: v for k, v in w if k != "parent"}) 920 value = value.model_construct(None, **{wf: new_w}) 921 922 return handler( 923 value, info # pyright: ignore[reportArgumentType] # taken from pydantic docs 924 ) 925 926 927class ModelDescr(GenericModelDescrBase, title="bioimage.io model specification"): 928 """Specification of the fields used in a bioimage.io-compliant RDF that describes AI models with pretrained weights. 929 930 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 931 """ 932 933 format_version: Literal["0.4.10",] = "0.4.10" 934 """Version of the bioimage.io model description specification used. 935 When creating a new model always use the latest micro/patch version described here. 
936 The `format_version` is important for any consumer software to understand how to parse the fields. 937 """ 938 939 type: Literal["model"] = "model" 940 """Specialized resource type 'model'""" 941 942 id: Optional[ModelId] = None 943 """bioimage.io-wide unique resource identifier 944 assigned by bioimage.io; version **un**specific.""" 945 946 authors: NotEmpty[ # pyright: ignore[reportGeneralTypeIssues] # make mandatory 947 List[Author] 948 ] 949 """The authors are the creators of the model RDF and the primary points of contact.""" 950 951 documentation: Annotated[ 952 ImportantFileSource, 953 Field( 954 examples=[ 955 "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md", 956 "README.md", 957 ], 958 ), 959 ] 960 """∈📦 URL or relative path to a markdown file with additional documentation. 961 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. 962 The documentation should include a '[#[#]]# Validation' (sub)section 963 with details on how to quantitatively validate the model on unseen data.""" 964 965 inputs: NotEmpty[List[InputTensorDescr]] 966 """Describes the input tensors expected by this model.""" 967 968 license: Annotated[ 969 Union[LicenseId, str], 970 warn(LicenseId, "Unknown license id '{value}'."), 971 Field(examples=["CC0-1.0", "MIT", "BSD-2-Clause"]), 972 ] 973 """A [SPDX license identifier](https://spdx.org/licenses/). 974 We do notsupport custom license beyond the SPDX license list, if you need that please 975 [open a GitHub issue](https://github.com/bioimage-io/spec-bioimage-io/issues/new/choose 976 ) to discuss your intentions with the community.""" 977 978 name: Annotated[ 979 str, 980 MinLen(1), 981 warn(MinLen(5), "Name shorter than 5 characters.", INFO), 982 warn(MaxLen(64), "Name longer than 64 characters.", INFO), 983 ] 984 """A human-readable name of this model. 985 It should be no longer than 64 characters and only contain letter, number, underscore, minus or space characters.""" 986 987 outputs: NotEmpty[List[OutputTensorDescr]] 988 """Describes the output tensors.""" 989 990 @field_validator("inputs", "outputs") 991 @classmethod 992 def unique_tensor_descr_names( 993 cls, value: Sequence[Union[InputTensorDescr, OutputTensorDescr]] 994 ) -> Sequence[Union[InputTensorDescr, OutputTensorDescr]]: 995 unique_names = {str(v.name) for v in value} 996 if len(unique_names) != len(value): 997 raise ValueError("Duplicate tensor descriptor names") 998 999 return value 1000 1001 @model_validator(mode="after") 1002 def unique_io_names(self) -> Self: 1003 unique_names = {str(ss.name) for s in (self.inputs, self.outputs) for ss in s} 1004 if len(unique_names) != (len(self.inputs) + len(self.outputs)): 1005 raise ValueError("Duplicate tensor descriptor names across inputs/outputs") 1006 1007 return self 1008 1009 @model_validator(mode="after") 1010 def minimum_shape2valid_output(self) -> Self: 1011 tensors_by_name: Dict[ 1012 TensorName, Union[InputTensorDescr, OutputTensorDescr] 1013 ] = {t.name: t for t in self.inputs + self.outputs} 1014 1015 for out in self.outputs: 1016 if isinstance(out.shape, ImplicitOutputShape): 1017 ndim_ref = len(tensors_by_name[out.shape.reference_tensor].shape) 1018 ndim_out_ref = len( 1019 [scale for scale in out.shape.scale if scale is not None] 1020 ) 1021 if ndim_ref != ndim_out_ref: 1022 expanded_dim_note = ( 1023 " Note that expanded dimensions (`scale`: null) are not" 1024 + f" counted for {out.name}'sdimensionality here." 
1025 if None in out.shape.scale 1026 else "" 1027 ) 1028 raise ValueError( 1029 f"Referenced tensor '{out.shape.reference_tensor}' with" 1030 + f" {ndim_ref} dimensions does not match output tensor" 1031 + f" '{out.name}' with" 1032 + f" {ndim_out_ref} dimensions.{expanded_dim_note}" 1033 ) 1034 1035 min_out_shape = self._get_min_shape(out, tensors_by_name) 1036 if out.halo: 1037 halo = out.halo 1038 halo_msg = f" for halo {out.halo}" 1039 else: 1040 halo = [0] * len(min_out_shape) 1041 halo_msg = "" 1042 1043 if any([s - 2 * h < 1 for s, h in zip(min_out_shape, halo)]): 1044 raise ValueError( 1045 f"Minimal shape {min_out_shape} of output {out.name} is too" 1046 + f" small{halo_msg}." 1047 ) 1048 1049 return self 1050 1051 @classmethod 1052 def _get_min_shape( 1053 cls, 1054 t: Union[InputTensorDescr, OutputTensorDescr], 1055 tensors_by_name: Dict[TensorName, Union[InputTensorDescr, OutputTensorDescr]], 1056 ) -> Sequence[int]: 1057 """output with subtracted halo has to result in meaningful output even for the minimal input 1058 see https://github.com/bioimage-io/spec-bioimage-io/issues/392 1059 """ 1060 if isinstance(t.shape, collections.abc.Sequence): 1061 return t.shape 1062 elif isinstance(t.shape, ParameterizedInputShape): 1063 return t.shape.min 1064 elif isinstance(t.shape, ImplicitOutputShape): 1065 pass 1066 else: 1067 assert_never(t.shape) 1068 1069 ref_shape = cls._get_min_shape( 1070 tensors_by_name[t.shape.reference_tensor], tensors_by_name 1071 ) 1072 1073 if None not in t.shape.scale: 1074 scale: Sequence[float, ...] = t.shape.scale # type: ignore 1075 else: 1076 expanded_dims = [idx for idx, sc in enumerate(t.shape.scale) if sc is None] 1077 new_ref_shape: List[int] = [] 1078 for idx in range(len(t.shape.scale)): 1079 ref_idx = idx - sum(int(exp < idx) for exp in expanded_dims) 1080 new_ref_shape.append(1 if idx in expanded_dims else ref_shape[ref_idx]) 1081 1082 ref_shape = new_ref_shape 1083 assert len(ref_shape) == len(t.shape.scale) 1084 scale = [0.0 if sc is None else sc for sc in t.shape.scale] 1085 1086 offset = t.shape.offset 1087 assert len(offset) == len(scale) 1088 return [int(rs * s + 2 * off) for rs, s, off in zip(ref_shape, scale, offset)] 1089 1090 @model_validator(mode="after") 1091 def validate_tensor_references_in_inputs(self) -> Self: 1092 for t in self.inputs: 1093 for proc in t.preprocessing: 1094 if "reference_tensor" not in proc.kwargs: 1095 continue 1096 1097 ref_tensor = proc.kwargs["reference_tensor"] 1098 if ref_tensor is not None and str(ref_tensor) not in { 1099 str(t.name) for t in self.inputs 1100 }: 1101 raise ValueError(f"'{ref_tensor}' not found in inputs") 1102 1103 if ref_tensor == t.name: 1104 raise ValueError( 1105 f"invalid self reference for preprocessing of tensor {t.name}" 1106 ) 1107 1108 return self 1109 1110 @model_validator(mode="after") 1111 def validate_tensor_references_in_outputs(self) -> Self: 1112 for t in self.outputs: 1113 for proc in t.postprocessing: 1114 if "reference_tensor" not in proc.kwargs: 1115 continue 1116 ref_tensor = proc.kwargs["reference_tensor"] 1117 if ref_tensor is not None and str(ref_tensor) not in { 1118 str(t.name) for t in self.inputs 1119 }: 1120 raise ValueError(f"{ref_tensor} not found in inputs") 1121 1122 return self 1123 1124 packaged_by: List[Author] = Field(default_factory=list) 1125 """The persons that have packaged and uploaded this model. 
1126 Only required if those persons differ from the `authors`.""" 1127 1128 parent: Optional[LinkedModel] = None 1129 """The model from which this model is derived, e.g. by fine-tuning the weights.""" 1130 1131 @field_validator("parent", mode="before") 1132 @classmethod 1133 def ignore_url_parent(cls, parent: Any): 1134 if isinstance(parent, dict): 1135 return None 1136 1137 else: 1138 return parent 1139 1140 run_mode: Optional[RunMode] = None 1141 """Custom run mode for this model: for more complex prediction procedures like test time 1142 data augmentation that currently cannot be expressed in the specification. 1143 No standard run modes are defined yet.""" 1144 1145 sample_inputs: List[ImportantFileSource] = Field(default_factory=list) 1146 """∈📦 URLs/relative paths to sample inputs to illustrate possible inputs for the model, 1147 for example stored as PNG or TIFF images. 1148 The sample files primarily serve to inform a human user about an example use case""" 1149 1150 sample_outputs: List[ImportantFileSource] = Field(default_factory=list) 1151 """∈📦 URLs/relative paths to sample outputs corresponding to the `sample_inputs`.""" 1152 1153 test_inputs: NotEmpty[ 1154 List[Annotated[ImportantFileSource, WithSuffix(".npy", case_sensitive=True)]] 1155 ] 1156 """∈📦 Test input tensors compatible with the `inputs` description for a **single test case**. 1157 This means if your model has more than one input, you should provide one URL/relative path for each input. 1158 Each test input should be a file with an ndarray in 1159 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format). 1160 The extension must be '.npy'.""" 1161 1162 test_outputs: NotEmpty[ 1163 List[Annotated[ImportantFileSource, WithSuffix(".npy", case_sensitive=True)]] 1164 ] 1165 """∈📦 Analog to `test_inputs`.""" 1166 1167 timestamp: Datetime 1168 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format 1169 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).""" 1170 1171 training_data: Union[LinkedDataset, DatasetDescr, None] = None 1172 """The dataset used to train this model""" 1173 1174 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)] 1175 """The weights for this model. 1176 Weights can be given for different formats, but should otherwise be equivalent. 1177 The available weight formats determine which consumers can use this model.""" 1178 1179 @model_validator(mode="before") 1180 @classmethod 1181 def _convert_from_older_format( 1182 cls, data: BioimageioYamlContent, / 1183 ) -> BioimageioYamlContent: 1184 convert_from_older_format(data) 1185 return data 1186 1187 def get_input_test_arrays(self) -> List[NDArray[Any]]: 1188 data = [load_array(download(ipt).path) for ipt in self.test_inputs] 1189 assert all(isinstance(d, np.ndarray) for d in data) 1190 return data 1191 1192 def get_output_test_arrays(self) -> List[NDArray[Any]]: 1193 data = [load_array(download(out).path) for out in self.test_outputs] 1194 assert all(isinstance(d, np.ndarray) for d in data) 1195 return data
116class CallableFromDepencency(StringNode): 117 _pattern = r"^.+\..+$" 118 _submodule_adapter = TypeAdapter(Identifier) 119 120 module_name: str 121 122 @field_validator("module_name", mode="after") 123 def check_submodules(cls, module_name: str) -> str: 124 for submod in module_name.split("."): 125 _ = cls._submodule_adapter.validate_python(submod) 126 127 return module_name 128 129 callable_name: Identifier 130 131 @classmethod 132 def _get_data(cls, valid_string_data: str): 133 *mods, callname = valid_string_data.split(".") 134 return dict(module_name=".".join(mods), callable_name=callname)
deprecated! don't use for new spec fields!
137class CallableFromFile(StringNode): 138 _pattern = r"^.+:.+$" 139 source_file: Annotated[ 140 Union[RelativeFilePath, HttpUrl], 141 Field(union_mode="left_to_right"), 142 include_in_package_serializer, 143 ] 144 """∈📦 Python module that implements `callable_name`""" 145 callable_name: Identifier 146 """The Python identifier of """ 147 148 @classmethod 149 def _get_data(cls, valid_string_data: str): 150 *file_parts, callname = valid_string_data.split(":") 151 return dict(source_file=":".join(file_parts), callable_name=callname)
deprecated! don't use for new spec fields!
∈📦 Python module that implements callable_name
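To illustrate the two string formats accepted here, the following is a hypothetical helper (not part of this module) that splits such strings in the same way as the `_get_data` classmethods shown above:

def parse_callable(spec: str) -> dict:
    if ":" in spec:
        # file based, e.g. "my_function.py:MyNetworkClass"
        *file_parts, callable_name = spec.split(":")
        return {"source_file": ":".join(file_parts), "callable_name": callable_name}
    # dependency based, e.g. "my_module.submodule.get_my_model"
    *modules, callable_name = spec.split(".")
    return {"module_name": ".".join(modules), "callable_name": callable_name}

assert parse_callable("my_function.py:MyNetworkClass")["callable_name"] == "MyNetworkClass"
assert parse_callable("my_module.submodule.get_my_model")["module_name"] == "my_module.submodule"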
159class Dependencies(StringNode): 160 _pattern = r"^.+:.+$" 161 manager: Annotated[NotEmpty[str], Field(examples=["conda", "maven", "pip"])] 162 """Dependency manager""" 163 164 file: Annotated[ 165 ImportantFileSource, 166 Field(examples=["environment.yaml", "pom.xml", "requirements.txt"]), 167 ] 168 """∈📦 Dependency file""" 169 170 @classmethod 171 def _get_data(cls, valid_string_data: str): 172 manager, *file_parts = valid_string_data.split(":") 173 return dict(manager=manager, file=":".join(file_parts))
deprecated! don't use for new spec fields!
Dependency manager
∈📦 Dependency file
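For example (plain Python, in the spirit of `Dependencies._get_data`), the string "conda:environment.yaml" splits into manager and file:

manager, _, file = "conda:environment.yaml".partition(":")
assert manager == "conda" and file == "environment.yaml"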
186class WeightsDescr(Node): 187 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None 188 onnx: Optional[OnnxWeightsDescr] = None 189 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None 190 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None 191 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = ( 192 None 193 ) 194 torchscript: Optional[TorchscriptWeightsDescr] = None 195 196 @model_validator(mode="after") 197 def check_one_entry(self) -> Self: 198 if all( 199 entry is None 200 for entry in [ 201 self.keras_hdf5, 202 self.onnx, 203 self.pytorch_state_dict, 204 self.tensorflow_js, 205 self.tensorflow_saved_model_bundle, 206 self.torchscript, 207 ] 208 ): 209 raise ValueError("Missing weights entry") 210 211 return self
Subpart of a resource description
196 @model_validator(mode="after") 197 def check_one_entry(self) -> Self: 198 if all( 199 entry is None 200 for entry in [ 201 self.keras_hdf5, 202 self.onnx, 203 self.pytorch_state_dict, 204 self.tensorflow_js, 205 self.tensorflow_saved_model_bundle, 206 self.torchscript, 207 ] 208 ): 209 raise ValueError("Missing weights entry") 210 211 return self
214class WeightsEntryDescrBase(FileDescr): 215 type: ClassVar[WeightsFormat] 216 weights_format_name: ClassVar[str] # human readable 217 218 source: ImportantFileSource 219 """∈📦 The weights file.""" 220 221 attachments: Annotated[ 222 Union[AttachmentsDescr, None], 223 warn(None, "Weights entry depends on additional attachments.", ALERT), 224 ] = None 225 """Attachments that are specific to this weights entry.""" 226 227 authors: Union[List[Author], None] = None 228 """Authors 229 Either the person(s) that have trained this model resulting in the original weights file. 230 (If this is the initial weights entry, i.e. it does not have a `parent`) 231 Or the person(s) who have converted the weights to this weights format. 232 (If this is a child weight, i.e. it has a `parent` field) 233 """ 234 235 dependencies: Annotated[ 236 Optional[Dependencies], 237 warn( 238 None, 239 "Custom dependencies ({value}) specified. Avoid this whenever possible " 240 + "to allow execution in a wider range of software environments.", 241 ), 242 Field( 243 examples=[ 244 "conda:environment.yaml", 245 "maven:./pom.xml", 246 "pip:./requirements.txt", 247 ] 248 ), 249 ] = None 250 """Dependency manager and dependency file, specified as `<dependency manager>:<relative file path>`.""" 251 252 parent: Annotated[ 253 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"]) 254 ] = None 255 """The source weights these weights were converted from. 256 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`, 257 The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights. 258 All weight entries except one (the initial set of weights resulting from training the model), 259 need to have this field.""" 260 261 @model_validator(mode="after") 262 def check_parent_is_not_self(self) -> Self: 263 if self.type == self.parent: 264 raise ValueError("Weights entry can't be it's own parent.") 265 266 return self
Subpart of a resource description
∈📦 The weights file.
Attachments that are specific to this weights entry.
Dependency manager and dependency file, specified as `<dependency manager>:<relative file path>`.

The source weights these weights were converted from.
For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`, the `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
All weight entries except one (the initial set of weights resulting from training the model) need to have this field.
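As a sketch of how `parent` ties weight entries together (field names follow this spec, the file names are made up, and required format-specific fields such as `architecture` are omitted):

weights = {
    "pytorch_state_dict": {              # initial weights from training: no `parent`
        "source": "weights.pt",
    },
    "torchscript": {
        "source": "weights_torchscript.pt",
        "parent": "pytorch_state_dict",  # converted from the entry above
    },
}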
269class KerasHdf5WeightsDescr(WeightsEntryDescrBase): 270 type = "keras_hdf5" 271 weights_format_name: ClassVar[str] = "Keras HDF5" 272 tensorflow_version: Optional[Version] = None 273 """TensorFlow version used to create these weights""" 274 275 @field_validator("tensorflow_version", mode="after") 276 @classmethod 277 def _tfv(cls, value: Any): 278 if value is None: 279 issue_warning( 280 "missing. Please specify the TensorFlow version" 281 + " these weights were created with.", 282 value=value, 283 severity=ALERT, 284 field="tensorflow_version", 285 ) 286 return value
Subpart of a resource description
TensorFlow version used to create these weights
289class OnnxWeightsDescr(WeightsEntryDescrBase): 290 type = "onnx" 291 weights_format_name: ClassVar[str] = "ONNX" 292 opset_version: Optional[Annotated[int, Ge(7)]] = None 293 """ONNX opset version""" 294 295 @field_validator("opset_version", mode="after") 296 @classmethod 297 def _ov(cls, value: Any): 298 if value is None: 299 issue_warning( 300 "Missing ONNX opset version (aka ONNX opset number). " 301 + "Please specify the ONNX opset version these weights were created" 302 + " with.", 303 value=value, 304 severity=ALERT, 305 field="opset_version", 306 ) 307 return value
Subpart of a resource description
310class PytorchStateDictWeightsDescr(WeightsEntryDescrBase): 311 type = "pytorch_state_dict" 312 weights_format_name: ClassVar[str] = "Pytorch State Dict" 313 architecture: CustomCallable = Field( 314 examples=["my_function.py:MyNetworkClass", "my_module.submodule.get_my_model"] 315 ) 316 """callable returning a torch.nn.Module instance. 317 Local implementation: `<relative path to file>:<identifier of implementation within the file>`. 318 Implementation in a dependency: `<dependency-package>.<[dependency-module]>.<identifier>`.""" 319 320 architecture_sha256: Annotated[ 321 Optional[Sha256], 322 Field( 323 description=( 324 "The SHA256 of the architecture source file, if the architecture is not" 325 " defined in a module listed in `dependencies`\n" 326 ) 327 + SHA256_HINT, 328 ), 329 ] = None 330 """The SHA256 of the architecture source file, 331 if the architecture is not defined in a module listed in `dependencies`""" 332 333 @model_validator(mode="after") 334 def check_architecture_sha256(self) -> Self: 335 if isinstance(self.architecture, CallableFromFile): 336 if self.architecture_sha256 is None: 337 raise ValueError( 338 "Missing required `architecture_sha256` for `architecture` with" 339 + " source file." 340 ) 341 elif self.architecture_sha256 is not None: 342 raise ValueError( 343 "Got `architecture_sha256` for architecture that does not have a source" 344 + " file." 345 ) 346 347 return self 348 349 kwargs: Dict[str, Any] = Field(default_factory=dict) 350 """key word arguments for the `architecture` callable""" 351 352 pytorch_version: Optional[Version] = None 353 """Version of the PyTorch library used. 354 If `depencencies` is specified it should include pytorch and the verison has to match. 355 (`dependencies` overrules `pytorch_version`)""" 356 357 @field_validator("pytorch_version", mode="after") 358 @classmethod 359 def _ptv(cls, value: Any): 360 if value is None: 361 issue_warning( 362 "missing. Please specify the PyTorch version these" 363 + " PyTorch state dict weights were created with.", 364 value=value, 365 severity=ALERT, 366 field="pytorch_version", 367 ) 368 return value
Subpart of a resource description
Callable returning a torch.nn.Module instance.
Local implementation: `<relative path to file>:<identifier of implementation within the file>`.
Implementation in a dependency: `<dependency-package>.<[dependency-module]>.<identifier>`.
The SHA256 of the architecture source file, if the architecture is not defined in a module listed in `dependencies`.
333 @model_validator(mode="after") 334 def check_architecture_sha256(self) -> Self: 335 if isinstance(self.architecture, CallableFromFile): 336 if self.architecture_sha256 is None: 337 raise ValueError( 338 "Missing required `architecture_sha256` for `architecture` with" 339 + " source file." 340 ) 341 elif self.architecture_sha256 is not None: 342 raise ValueError( 343 "Got `architecture_sha256` for architecture that does not have a source" 344 + " file." 345 ) 346 347 return self
Version of the PyTorch library used.
If `dependencies` is specified, it should include pytorch and the version has to match
(`dependencies` overrules `pytorch_version`).
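Combining the fields above, a `pytorch_state_dict` weights entry could look roughly like this (a sketch; the file name, hash placeholder and kwargs are made up):

pytorch_state_dict_entry = {
    "source": "weights.pt",
    "architecture": "my_function.py:MyNetworkClass",      # local source file ...
    "architecture_sha256": "<sha256 of my_function.py>",  # ... so the hash is required
    "kwargs": {"num_classes": 2},                         # passed to MyNetworkClass(...)
    "pytorch_version": "1.13",
}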
371class TorchscriptWeightsDescr(WeightsEntryDescrBase): 372 type = "torchscript" 373 weights_format_name: ClassVar[str] = "TorchScript" 374 pytorch_version: Optional[Version] = None 375 """Version of the PyTorch library used.""" 376 377 @field_validator("pytorch_version", mode="after") 378 @classmethod 379 def _ptv(cls, value: Any): 380 if value is None: 381 issue_warning( 382 "missing. Please specify the PyTorch version these" 383 + " Torchscript weights were created with.", 384 value=value, 385 severity=ALERT, 386 field="pytorch_version", 387 ) 388 return value
Subpart of a resource description
Version of the PyTorch library used.
391class TensorflowJsWeightsDescr(WeightsEntryDescrBase): 392 type = "tensorflow_js" 393 weights_format_name: ClassVar[str] = "Tensorflow.js" 394 tensorflow_version: Optional[Version] = None 395 """Version of the TensorFlow library used.""" 396 397 @field_validator("tensorflow_version", mode="after") 398 @classmethod 399 def _tfv(cls, value: Any): 400 if value is None: 401 issue_warning( 402 "missing. Please specify the TensorFlow version" 403 + " these TensorflowJs weights were created with.", 404 value=value, 405 severity=ALERT, 406 field="tensorflow_version", 407 ) 408 return value 409 410 source: ImportantFileSource 411 """∈📦 The multi-file weights. 412 All required files/folders should be a zip archive."""
Subpart of a resource description
Version of the TensorFlow library used.
∈📦 The multi-file weights. All required files/folders should be a zip archive.
415class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase): 416 type = "tensorflow_saved_model_bundle" 417 weights_format_name: ClassVar[str] = "Tensorflow Saved Model" 418 tensorflow_version: Optional[Version] = None 419 """Version of the TensorFlow library used.""" 420 421 @field_validator("tensorflow_version", mode="after") 422 @classmethod 423 def _tfv(cls, value: Any): 424 if value is None: 425 issue_warning( 426 "missing. Please specify the TensorFlow version" 427 + " these Tensorflow saved model bundle weights were created with.", 428 value=value, 429 severity=ALERT, 430 field="tensorflow_version", 431 ) 432 return value
Subpart of a resource description
Version of the TensorFlow library used.
435class ParameterizedInputShape(Node): 436 """A sequence of valid shapes given by `shape_k = min + k * step for k in {0, 1, ...}`.""" 437 438 min: NotEmpty[List[int]] 439 """The minimum input shape""" 440 441 step: NotEmpty[List[int]] 442 """The minimum shape change""" 443 444 def __len__(self) -> int: 445 return len(self.min) 446 447 @model_validator(mode="after") 448 def matching_lengths(self) -> Self: 449 if len(self.min) != len(self.step): 450 raise ValueError("`min` and `step` required to have the same length") 451 452 return self
A sequence of valid shapes given by `shape_k = min + k * step for k in {0, 1, ...}`.
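A small numeric sketch of that formula, using the example values from `InputTensorDescr.shape` later in this module (min=[1, 64, 64, 1], step=[0, 32, 32, 0]):

min_shape = [1, 64, 64, 1]
step = [0, 32, 32, 0]
shapes = [[m + k * s for m, s in zip(min_shape, step)] for k in range(3)]
assert shapes == [[1, 64, 64, 1], [1, 96, 96, 1], [1, 128, 128, 1]]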
455class ImplicitOutputShape(Node): 456 """Output tensor shape depending on an input tensor shape. 457 `shape(output_tensor) = shape(input_tensor) * scale + 2 * offset`""" 458 459 reference_tensor: TensorName 460 """Name of the reference tensor.""" 461 462 scale: NotEmpty[List[Optional[float]]] 463 """output_pix/input_pix for each dimension. 464 'null' values indicate new dimensions, whose length is defined by 2*`offset`""" 465 466 offset: NotEmpty[List[Union[int, Annotated[float, MultipleOf(0.5)]]]] 467 """Position of origin wrt to input.""" 468 469 def __len__(self) -> int: 470 return len(self.scale) 471 472 @model_validator(mode="after") 473 def matching_lengths(self) -> Self: 474 if len(self.scale) != len(self.offset): 475 raise ValueError( 476 f"scale {self.scale} has to have same length as offset {self.offset}!" 477 ) 478 # if we have an expanded dimension, make sure that it's offet is not zero 479 for sc, off in zip(self.scale, self.offset): 480 if sc is None and not off: 481 raise ValueError("`offset` must not be zero if `scale` is none/zero") 482 483 return self
Output tensor shape depending on an input tensor shape.
`shape(output_tensor) = shape(input_tensor) * scale + 2 * offset`
output_pix/input_pix for each dimension.
'null' values indicate new dimensions, whose length is defined by 2 * `offset`.
Position of origin with respect to the input.
472 @model_validator(mode="after") 473 def matching_lengths(self) -> Self: 474 if len(self.scale) != len(self.offset): 475 raise ValueError( 476 f"scale {self.scale} has to have same length as offset {self.offset}!" 477 ) 478 # if we have an expanded dimension, make sure that it's offet is not zero 479 for sc, off in zip(self.scale, self.offset): 480 if sc is None and not off: 481 raise ValueError("`offset` must not be zero if `scale` is none/zero") 482 483 return self
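As a numeric sketch of `shape(output_tensor) = shape(input_tensor) * scale + 2 * offset` with made-up values, including a 'null' scale entry for a newly inserted dimension (the library's own handling of such entries lives in `ModelDescr._get_min_shape`):

input_shape = [1, 512, 512]        # shape of the reference (input) tensor
scale = [1.0, 0.5, 0.5, None]      # None/'null': new dimension, not mapped to the input
offset = [0, 0, 0, 1.5]            # the new dimension gets length 2 * 1.5 = 3

output_shape = []
ref_idx = 0
for sc, off in zip(scale, offset):
    if sc is None:                 # expanded dimension: length defined by 2 * offset
        output_shape.append(int(2 * off))
    else:
        output_shape.append(int(input_shape[ref_idx] * sc + 2 * off))
        ref_idx += 1
assert output_shape == [1, 256, 256, 3]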
486class TensorDescrBase(Node): 487 name: TensorName 488 """Tensor name. No duplicates are allowed.""" 489 490 description: str = "" 491 492 axes: AxesStr 493 """Axes identifying characters. Same length and order as the axes in `shape`. 494 | axis | description | 495 | --- | --- | 496 | b | batch (groups multiple samples) | 497 | i | instance/index/element | 498 | t | time | 499 | c | channel | 500 | z | spatial dimension z | 501 | y | spatial dimension y | 502 | x | spatial dimension x | 503 """ 504 505 data_range: Optional[ 506 Tuple[Annotated[float, AllowInfNan(True)], Annotated[float, AllowInfNan(True)]] 507 ] = None 508 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor. 509 If not specified, the full data range that can be expressed in `data_type` is allowed."""
Subpart of a resource description
Axes identifying characters. Same length and order as the axes in `shape`.

| axis | description |
| --- | --- |
| b | batch (groups multiple samples) |
| i | instance/index/element |
| t | time |
| c | channel |
| z | spatial dimension z |
| y | spatial dimension y |
| x | spatial dimension x |
512class ProcessingKwargs(KwargsNode): 513 """base class for pre-/postprocessing key word arguments"""
Base class for pre-/postprocessing keyword arguments
516class ProcessingDescrBase(NodeWithExplicitlySetFields): 517 """processing base class""" 518 519 # name: Literal[PreprocessingName, PostprocessingName] # todo: make abstract field 520 fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({"name"})
processing base class
523class BinarizeKwargs(ProcessingKwargs): 524 """key word arguments for `BinarizeDescr`""" 525 526 threshold: float 527 """The fixed threshold"""
Keyword arguments for `BinarizeDescr`
530class BinarizeDescr(ProcessingDescrBase): 531 """BinarizeDescr the tensor with a fixed `BinarizeKwargs.threshold`. 532 Values above the threshold will be set to one, values below the threshold to zero. 533 """ 534 535 name: Literal["binarize"] = "binarize" 536 kwargs: BinarizeKwargs
Binarize the tensor with a fixed `BinarizeKwargs.threshold`. Values above the threshold will be set to one, values below the threshold to zero.
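A minimal sketch (the threshold value is hypothetical) of a binarize step as it could appear in a preprocessing list:

from bioimageio.spec.model.v0_4 import BinarizeDescr, BinarizeKwargs

binarize = BinarizeDescr(kwargs=BinarizeKwargs(threshold=0.5))
# effect on the data: out = tensor > 0.5 (1 above the threshold, 0 below)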
539class ClipKwargs(ProcessingKwargs): 540 """key word arguments for `ClipDescr`""" 541 542 min: float 543 """minimum value for clipping""" 544 max: float 545 """maximum value for clipping"""
Keyword arguments for `ClipDescr`
548class ClipDescr(ProcessingDescrBase): 549 """Clip tensor values to a range. 550 551 Set tensor values below `ClipKwargs.min` to `ClipKwargs.min` 552 and above `ClipKwargs.max` to `ClipKwargs.max`. 553 """ 554 555 name: Literal["clip"] = "clip" 556 557 kwargs: ClipKwargs
Clip tensor values to a range.
Set tensor values below `ClipKwargs.min` to `ClipKwargs.min` and above `ClipKwargs.max` to `ClipKwargs.max`.
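A minimal sketch (the range values are hypothetical) of a clip step, e.g. to keep values in [0, 1]:

from bioimageio.spec.model.v0_4 import ClipDescr, ClipKwargs

clip = ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0))
# effect on the data: out = min(max(tensor, 0.0), 1.0)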
560class ScaleLinearKwargs(ProcessingKwargs): 561 """key word arguments for `ScaleLinearDescr`""" 562 563 axes: Annotated[Optional[AxesInCZYX], Field(examples=["xy"])] = None 564 """The subset of axes to scale jointly. 565 For example xy to scale the two image axes for 2d data jointly.""" 566 567 gain: Union[float, List[float]] = 1.0 568 """multiplicative factor""" 569 570 offset: Union[float, List[float]] = 0.0 571 """additive term""" 572 573 @model_validator(mode="after") 574 def either_gain_or_offset(self) -> Self: 575 if ( 576 self.gain == 1.0 577 or isinstance(self.gain, list) 578 and all(g == 1.0 for g in self.gain) 579 ) and ( 580 self.offset == 0.0 581 or isinstance(self.offset, list) 582 and all(off == 0.0 for off in self.offset) 583 ): 584 raise ValueError( 585 "Redunt linear scaling not allowd. Set `gain` != 1.0 and/or `offset` !=" 586 + " 0.0." 587 ) 588 589 return self
Keyword arguments for `ScaleLinearDescr`
The subset of axes to scale jointly. For example `xy` to scale the two image axes for 2d data jointly.
573 @model_validator(mode="after") 574 def either_gain_or_offset(self) -> Self: 575 if ( 576 self.gain == 1.0 577 or isinstance(self.gain, list) 578 and all(g == 1.0 for g in self.gain) 579 ) and ( 580 self.offset == 0.0 581 or isinstance(self.offset, list) 582 and all(off == 0.0 for off in self.offset) 583 ): 584 raise ValueError( 585 "Redunt linear scaling not allowd. Set `gain` != 1.0 and/or `offset` !=" 586 + " 0.0." 587 ) 588 589 return self
592class ScaleLinearDescr(ProcessingDescrBase): 593 """Fixed linear scaling.""" 594 595 name: Literal["scale_linear"] = "scale_linear" 596 kwargs: ScaleLinearKwargs
Fixed linear scaling.
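A minimal sketch (the gain and offset values are hypothetical) of a linear scaling step with one gain and offset per channel, applied jointly over the image axes:

from bioimageio.spec.model.v0_4 import ScaleLinearDescr, ScaleLinearKwargs

scale_linear = ScaleLinearDescr(
    kwargs=ScaleLinearKwargs(axes="xy", gain=[2.0, 0.5, 1.0], offset=[0.0, 1.0, 0.0])
)
# effect on the data: out = tensor * gain + offset (per channel here)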
599class SigmoidDescr(ProcessingDescrBase): 600 """The logistic sigmoid funciton, a.k.a. expit function.""" 601 602 name: Literal["sigmoid"] = "sigmoid" 603 604 @property 605 def kwargs(self) -> ProcessingKwargs: 606 """empty kwargs""" 607 return ProcessingKwargs()
The logistic sigmoid function, a.k.a. expit function.
604 @property 605 def kwargs(self) -> ProcessingKwargs: 606 """empty kwargs""" 607 return ProcessingKwargs()
empty kwargs
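A minimal sketch of this step; it takes no keyword arguments, and the comment shows what a consumer computes:

from bioimageio.spec.model.v0_4 import SigmoidDescr
import numpy as np

sigmoid = SigmoidDescr()
x = np.array([-2.0, 0.0, 2.0])
print(1 / (1 + np.exp(-x)))  # effect on the data: [0.119..., 0.5, 0.880...]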
610class ZeroMeanUnitVarianceKwargs(ProcessingKwargs): 611 """key word arguments for `ZeroMeanUnitVarianceDescr`""" 612 613 mode: Literal["fixed", "per_dataset", "per_sample"] = "fixed" 614 """Mode for computing mean and variance. 615 | mode | description | 616 | ----------- | ------------------------------------ | 617 | fixed | Fixed values for mean and variance | 618 | per_dataset | Compute for the entire dataset | 619 | per_sample | Compute for each sample individually | 620 """ 621 axes: Annotated[AxesInCZYX, Field(examples=["xy"])] 622 """The subset of axes to normalize jointly. 623 For example `xy` to normalize the two image axes for 2d data jointly.""" 624 625 mean: Annotated[ 626 Union[float, NotEmpty[List[float]], None], Field(examples=[(1.1, 2.2, 3.3)]) 627 ] = None 628 """The mean value(s) to use for `mode: fixed`. 629 For example `[1.1, 2.2, 3.3]` in the case of a 3 channel image with `axes: xy`.""" 630 # todo: check if means match input axes (for mode 'fixed') 631 632 std: Annotated[ 633 Union[float, NotEmpty[List[float]], None], Field(examples=[(0.1, 0.2, 0.3)]) 634 ] = None 635 """The standard deviation values to use for `mode: fixed`. Analogous to mean.""" 636 637 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 638 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`.""" 639 640 @model_validator(mode="after") 641 def mean_and_std_match_mode(self) -> Self: 642 if self.mode == "fixed" and (self.mean is None or self.std is None): 643 raise ValueError("`mean` and `std` are required for `mode: fixed`.") 644 elif self.mode != "fixed" and (self.mean is not None or self.std is not None): 645 raise ValueError(f"`mean` and `std` not allowed for `mode: {self.mode}`") 646 647 return self
Keyword arguments for `ZeroMeanUnitVarianceDescr`
Mode for computing mean and variance.
| mode | description |
| --- | --- |
| fixed | Fixed values for mean and variance |
| per_dataset | Compute for the entire dataset |
| per_sample | Compute for each sample individually |
The subset of axes to normalize jointly. For example `xy` to normalize the two image axes for 2d data jointly.
The mean value(s) to use for `mode: fixed`. For example `[1.1, 2.2, 3.3]` in the case of a 3 channel image with `axes: xy`.
The standard deviation values to use for `mode: fixed`. Analogous to mean.
epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`.
640 @model_validator(mode="after") 641 def mean_and_std_match_mode(self) -> Self: 642 if self.mode == "fixed" and (self.mean is None or self.std is None): 643 raise ValueError("`mean` and `std` are required for `mode: fixed`.") 644 elif self.mode != "fixed" and (self.mean is not None or self.std is not None): 645 raise ValueError(f"`mean` and `std` not allowed for `mode: {self.mode}`") 646 647 return self
650class ZeroMeanUnitVarianceDescr(ProcessingDescrBase): 651 """Subtract mean and divide by variance.""" 652 653 name: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance" 654 kwargs: ZeroMeanUnitVarianceKwargs
Subtract mean and divide by variance.
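A minimal sketch (the statistics are hypothetical) of a fixed-mode normalization for a 3-channel image, normalized jointly over x and y:

from bioimageio.spec.model.v0_4 import (
    ZeroMeanUnitVarianceDescr,
    ZeroMeanUnitVarianceKwargs,
)

zmuv = ZeroMeanUnitVarianceDescr(
    kwargs=ZeroMeanUnitVarianceKwargs(
        mode="fixed", axes="xy", mean=[1.1, 2.2, 3.3], std=[0.1, 0.2, 0.3]
    )
)
# effect on the data: out = (tensor - mean) / (std + eps)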
657class ScaleRangeKwargs(ProcessingKwargs): 658 """key word arguments for `ScaleRangeDescr` 659 660 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default) 661 this processing step normalizes data to the [0, 1] intervall. 662 For other percentiles the normalized values will partially be outside the [0, 1] 663 intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the 664 normalized values to a range. 665 """ 666 667 mode: Literal["per_dataset", "per_sample"] 668 """Mode for computing percentiles. 669 | mode | description | 670 | ----------- | ------------------------------------ | 671 | per_dataset | compute for the entire dataset | 672 | per_sample | compute for each sample individually | 673 """ 674 axes: Annotated[AxesInCZYX, Field(examples=["xy"])] 675 """The subset of axes to normalize jointly. 676 For example xy to normalize the two image axes for 2d data jointly.""" 677 678 min_percentile: Annotated[Union[int, float], Interval(ge=0, lt=100)] = 0.0 679 """The lower percentile used to determine the value to align with zero.""" 680 681 max_percentile: Annotated[Union[int, float], Interval(gt=1, le=100)] = 100.0 682 """The upper percentile used to determine the value to align with one. 683 Has to be bigger than `min_percentile`. 684 The range is 1 to 100 instead of 0 to 100 to avoid mistakenly 685 accepting percentiles specified in the range 0.0 to 1.0.""" 686 687 @model_validator(mode="after") 688 def min_smaller_max(self, info: ValidationInfo) -> Self: 689 if self.min_percentile >= self.max_percentile: 690 raise ValueError( 691 f"min_percentile {self.min_percentile} >= max_percentile" 692 + f" {self.max_percentile}" 693 ) 694 695 return self 696 697 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 698 """Epsilon for numeric stability. 699 `out = (tensor - v_lower) / (v_upper - v_lower + eps)`; 700 with `v_lower,v_upper` values at the respective percentiles.""" 701 702 reference_tensor: Optional[TensorName] = None 703 """Tensor name to compute the percentiles from. Default: The tensor itself. 704 For any tensor in `inputs` only input tensor references are allowed. 705 For a tensor in `outputs` only input tensor refereences are allowed if `mode: per_dataset`"""
Keyword arguments for `ScaleRangeDescr`

For `min_percentile` = 0.0 (the default) and `max_percentile` = 100 (the default) this processing step normalizes data to the [0, 1] interval. For other percentiles the normalized values will partially be outside the [0, 1] interval. Use `ScaleRange` followed by `ClipDescr` if you want to limit the normalized values to a range.
Mode for computing percentiles.
| mode | description |
| --- | --- |
| per_dataset | compute for the entire dataset |
| per_sample | compute for each sample individually |
The subset of axes to normalize jointly. For example `xy` to normalize the two image axes for 2d data jointly.
The lower percentile used to determine the value to align with zero.
The upper percentile used to determine the value to align with one. Has to be bigger than `min_percentile`. The range is 1 to 100 instead of 0 to 100 to avoid mistakenly accepting percentiles specified in the range 0.0 to 1.0.
Epsilon for numeric stability: `out = (tensor - v_lower) / (v_upper - v_lower + eps)`; with `v_lower`, `v_upper` values at the respective percentiles.
Tensor name to compute the percentiles from. Default: the tensor itself. For any tensor in `inputs` only input tensor references are allowed. For a tensor in `outputs` only input tensor references are allowed if `mode: per_dataset`.
708class ScaleRangeDescr(ProcessingDescrBase): 709 """Scale with percentiles.""" 710 711 name: Literal["scale_range"] = "scale_range" 712 kwargs: ScaleRangeKwargs
Scale with percentiles.
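A minimal sketch (the percentile values are hypothetical) of a robust per-sample normalization using the 1st and 99.8th percentiles over x and y:

from bioimageio.spec.model.v0_4 import ScaleRangeDescr, ScaleRangeKwargs

scale_range = ScaleRangeDescr(
    kwargs=ScaleRangeKwargs(
        mode="per_sample", axes="xy", min_percentile=1.0, max_percentile=99.8
    )
)
# effect on the data: out = (tensor - v_lower) / (v_upper - v_lower + eps)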
715class ScaleMeanVarianceKwargs(ProcessingKwargs): 716 """key word arguments for `ScaleMeanVarianceDescr`""" 717 718 mode: Literal["per_dataset", "per_sample"] 719 """Mode for computing mean and variance. 720 | mode | description | 721 | ----------- | ------------------------------------ | 722 | per_dataset | Compute for the entire dataset | 723 | per_sample | Compute for each sample individually | 724 """ 725 726 reference_tensor: TensorName 727 """Name of tensor to match.""" 728 729 axes: Annotated[Optional[AxesInCZYX], Field(examples=["xy"])] = None 730 """The subset of axes to scale jointly. 731 For example xy to normalize the two image axes for 2d data jointly. 732 Default: scale all non-batch axes jointly.""" 733 734 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 735 """Epsilon for numeric stability: 736 "`out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean."""
Keyword arguments for `ScaleMeanVarianceDescr`
Mode for computing mean and variance.
| mode | description |
| --- | --- |
| per_dataset | Compute for the entire dataset |
| per_sample | Compute for each sample individually |
The subset of axes to scale jointly. For example `xy` to normalize the two image axes for 2d data jointly. Default: scale all non-batch axes jointly.
739class ScaleMeanVarianceDescr(ProcessingDescrBase): 740 """Scale the tensor s.t. its mean and variance match a reference tensor.""" 741 742 name: Literal["scale_mean_variance"] = "scale_mean_variance" 743 kwargs: ScaleMeanVarianceKwargs
Scale the tensor s.t. its mean and variance match a reference tensor.
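A minimal sketch (the tensor name is hypothetical) matching an output's per-sample statistics to those of an input tensor:

from bioimageio.spec.model.v0_4 import ScaleMeanVarianceDescr, ScaleMeanVarianceKwargs

scale_mv = ScaleMeanVarianceDescr(
    kwargs=ScaleMeanVarianceKwargs(mode="per_sample", reference_tensor="raw")
)
# effect on the data: out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean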
771class InputTensorDescr(TensorDescrBase): 772 data_type: Literal["float32", "uint8", "uint16"] 773 """For now an input tensor is expected to be given as `float32`. 774 The data flow in bioimage.io models is explained 775 [in this diagram.](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit).""" 776 777 shape: Annotated[ 778 Union[Sequence[int], ParameterizedInputShape], 779 Field( 780 examples=[(1, 512, 512, 1), dict(min=(1, 64, 64, 1), step=(0, 32, 32, 0))] 781 ), 782 ] 783 """Specification of input tensor shape.""" 784 785 preprocessing: List[PreprocessingDescr] = Field(default_factory=list) 786 """Description of how this input should be preprocessed.""" 787 788 @model_validator(mode="after") 789 def zero_batch_step_and_one_batch_size(self) -> Self: 790 bidx = self.axes.find("b") 791 if bidx == -1: 792 return self 793 794 if isinstance(self.shape, ParameterizedInputShape): 795 step = self.shape.step 796 shape = self.shape.min 797 if step[bidx] != 0: 798 raise ValueError( 799 "Input shape step has to be zero in the batch dimension (the batch" 800 + " dimension can always be increased, but `step` should specify how" 801 + " to increase the minimal shape to find the largest single batch" 802 + " shape)" 803 ) 804 else: 805 shape = self.shape 806 807 if shape[bidx] != 1: 808 raise ValueError("Input shape has to be 1 in the batch dimension b.") 809 810 return self 811 812 @model_validator(mode="after") 813 def validate_preprocessing_kwargs(self) -> Self: 814 for p in self.preprocessing: 815 kwargs_axes = p.kwargs.get("axes", "") 816 if not isinstance(kwargs_axes, str): 817 raise ValueError( 818 f"Expected an `axes` string, but got {type(kwargs_axes)}" 819 ) 820 821 if any(a not in self.axes for a in kwargs_axes): 822 raise ValueError("`kwargs.axes` needs to be subset of `axes`") 823 824 return self
Subpart of a resource description
For now an input tensor is expected to be given as `float32`. The data flow in bioimage.io models is explained [in this diagram](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit).
Specification of input tensor shape.
Description of how this input should be preprocessed.
788 @model_validator(mode="after") 789 def zero_batch_step_and_one_batch_size(self) -> Self: 790 bidx = self.axes.find("b") 791 if bidx == -1: 792 return self 793 794 if isinstance(self.shape, ParameterizedInputShape): 795 step = self.shape.step 796 shape = self.shape.min 797 if step[bidx] != 0: 798 raise ValueError( 799 "Input shape step has to be zero in the batch dimension (the batch" 800 + " dimension can always be increased, but `step` should specify how" 801 + " to increase the minimal shape to find the largest single batch" 802 + " shape)" 803 ) 804 else: 805 shape = self.shape 806 807 if shape[bidx] != 1: 808 raise ValueError("Input shape has to be 1 in the batch dimension b.") 809 810 return self
812 @model_validator(mode="after") 813 def validate_preprocessing_kwargs(self) -> Self: 814 for p in self.preprocessing: 815 kwargs_axes = p.kwargs.get("axes", "") 816 if not isinstance(kwargs_axes, str): 817 raise ValueError( 818 f"Expected an `axes` string, but got {type(kwargs_axes)}" 819 ) 820 821 if any(a not in self.axes for a in kwargs_axes): 822 raise ValueError("`kwargs.axes` needs to be subset of `axes`") 823 824 return self
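Putting the pieces together, a hedged sketch (names, axes, and shapes are hypothetical) of an input tensor description with a parameterized shape and one preprocessing step:

from bioimageio.spec.model.v0_4 import (
    InputTensorDescr,
    ZeroMeanUnitVarianceDescr,
    ZeroMeanUnitVarianceKwargs,
)

ipt = InputTensorDescr(
    name="raw",
    data_type="float32",
    axes="bcyx",
    shape=dict(min=[1, 1, 64, 64], step=[0, 0, 32, 32]),  # parameterized shape, batch step 0
    preprocessing=[
        ZeroMeanUnitVarianceDescr(
            kwargs=ZeroMeanUnitVarianceKwargs(mode="per_sample", axes="yx")
        )
    ],
)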
827class OutputTensorDescr(TensorDescrBase): 828 data_type: Literal[ 829 "float32", 830 "float64", 831 "uint8", 832 "int8", 833 "uint16", 834 "int16", 835 "uint32", 836 "int32", 837 "uint64", 838 "int64", 839 "bool", 840 ] 841 """Data type. 842 The data flow in bioimage.io models is explained 843 [in this diagram.](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit).""" 844 845 shape: Union[Sequence[int], ImplicitOutputShape] 846 """Output tensor shape.""" 847 848 halo: Optional[Sequence[int]] = None 849 """The `halo` that should be cropped from the output tensor to avoid boundary effects. 850 The `halo` is to be cropped from both sides, i.e. `shape_after_crop = shape - 2 * halo`. 851 To document a `halo` that is already cropped by the model `shape.offset` has to be used instead.""" 852 853 postprocessing: List[PostprocessingDescr] = Field(default_factory=list) 854 """Description of how this output should be postprocessed.""" 855 856 @model_validator(mode="after") 857 def matching_halo_length(self) -> Self: 858 if self.halo and len(self.halo) != len(self.shape): 859 raise ValueError( 860 f"halo {self.halo} has to have same length as shape {self.shape}!" 861 ) 862 863 return self 864 865 @model_validator(mode="after") 866 def validate_postprocessing_kwargs(self) -> Self: 867 for p in self.postprocessing: 868 kwargs_axes = p.kwargs.get("axes", "") 869 if not isinstance(kwargs_axes, str): 870 raise ValueError(f"Expected {kwargs_axes} to be a string") 871 872 if any(a not in self.axes for a in kwargs_axes): 873 raise ValueError("`kwargs.axes` needs to be subset of axes") 874 875 return self
Subpart of a resource description
Data type. The data flow in bioimage.io models is explained [in this diagram](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit).
Description of how this output should be postprocessed.
865 @model_validator(mode="after") 866 def validate_postprocessing_kwargs(self) -> Self: 867 for p in self.postprocessing: 868 kwargs_axes = p.kwargs.get("axes", "") 869 if not isinstance(kwargs_axes, str): 870 raise ValueError(f"Expected {kwargs_axes} to be a string") 871 872 if any(a not in self.axes for a in kwargs_axes): 873 raise ValueError("`kwargs.axes` needs to be subset of axes") 874 875 return self
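Analogously, a hedged sketch (names and values are hypothetical) of an output tensor description whose shape is derived from an input tensor and that declares a halo to crop:

from bioimageio.spec.model.v0_4 import ImplicitOutputShape, OutputTensorDescr

out = OutputTensorDescr(
    name="probabilities",
    data_type="float32",
    axes="bcyx",
    shape=ImplicitOutputShape(
        reference_tensor="raw", scale=[1.0, 1.0, 1.0, 1.0], offset=[0, 0, 0, 0]
    ),
    halo=[0, 0, 8, 8],  # crop 8 pixels from each side in y and x
)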
881class RunMode(Node): 882 name: Annotated[ 883 Union[KnownRunMode, str], warn(KnownRunMode, "Unknown run mode '{value}'.") 884 ] 885 """Run mode name""" 886 887 kwargs: Dict[str, Any] = Field(default_factory=dict) 888 """Run mode specific key word arguments"""
Subpart of a resource description
891class LinkedModel(Node): 892 """Reference to a bioimage.io model.""" 893 894 id: ModelId 895 """A valid model `id` from the bioimage.io collection.""" 896 897 version_number: Optional[int] = None 898 """version number (n-th published version, not the semantic version) of linked model"""
Reference to a bioimage.io model.
901def package_weights( 902 value: Node, # Union[v0_4.WeightsDescr, v0_5.WeightsDescr] 903 handler: SerializerFunctionWrapHandler, 904 info: SerializationInfo, 905): 906 ctxt = packaging_context_var.get() 907 if ctxt is not None and ctxt.weights_priority_order is not None: 908 for wf in ctxt.weights_priority_order: 909 w = getattr(value, wf, None) 910 if w is not None: 911 break 912 else: 913 raise ValueError( 914 "None of the weight formats in `weights_priority_order`" 915 + f" ({ctxt.weights_priority_order}) is present in the given model." 916 ) 917 918 assert isinstance(w, Node), type(w) 919 # construct WeightsDescr with new single weight format entry 920 new_w = w.model_construct(**{k: v for k, v in w if k != "parent"}) 921 value = value.model_construct(None, **{wf: new_w}) 922 923 return handler( 924 value, info # pyright: ignore[reportArgumentType] # taken from pydantic docs 925 )
928class ModelDescr(GenericModelDescrBase, title="bioimage.io model specification"): 929 """Specification of the fields used in a bioimage.io-compliant RDF that describes AI models with pretrained weights. 930 931 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 932 """ 933 934 format_version: Literal["0.4.10",] = "0.4.10" 935 """Version of the bioimage.io model description specification used. 936 When creating a new model always use the latest micro/patch version described here. 937 The `format_version` is important for any consumer software to understand how to parse the fields. 938 """ 939 940 type: Literal["model"] = "model" 941 """Specialized resource type 'model'""" 942 943 id: Optional[ModelId] = None 944 """bioimage.io-wide unique resource identifier 945 assigned by bioimage.io; version **un**specific.""" 946 947 authors: NotEmpty[ # pyright: ignore[reportGeneralTypeIssues] # make mandatory 948 List[Author] 949 ] 950 """The authors are the creators of the model RDF and the primary points of contact.""" 951 952 documentation: Annotated[ 953 ImportantFileSource, 954 Field( 955 examples=[ 956 "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md", 957 "README.md", 958 ], 959 ), 960 ] 961 """∈📦 URL or relative path to a markdown file with additional documentation. 962 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. 963 The documentation should include a '[#[#]]# Validation' (sub)section 964 with details on how to quantitatively validate the model on unseen data.""" 965 966 inputs: NotEmpty[List[InputTensorDescr]] 967 """Describes the input tensors expected by this model.""" 968 969 license: Annotated[ 970 Union[LicenseId, str], 971 warn(LicenseId, "Unknown license id '{value}'."), 972 Field(examples=["CC0-1.0", "MIT", "BSD-2-Clause"]), 973 ] 974 """A [SPDX license identifier](https://spdx.org/licenses/). 975 We do notsupport custom license beyond the SPDX license list, if you need that please 976 [open a GitHub issue](https://github.com/bioimage-io/spec-bioimage-io/issues/new/choose 977 ) to discuss your intentions with the community.""" 978 979 name: Annotated[ 980 str, 981 MinLen(1), 982 warn(MinLen(5), "Name shorter than 5 characters.", INFO), 983 warn(MaxLen(64), "Name longer than 64 characters.", INFO), 984 ] 985 """A human-readable name of this model. 
986 It should be no longer than 64 characters and only contain letter, number, underscore, minus or space characters.""" 987 988 outputs: NotEmpty[List[OutputTensorDescr]] 989 """Describes the output tensors.""" 990 991 @field_validator("inputs", "outputs") 992 @classmethod 993 def unique_tensor_descr_names( 994 cls, value: Sequence[Union[InputTensorDescr, OutputTensorDescr]] 995 ) -> Sequence[Union[InputTensorDescr, OutputTensorDescr]]: 996 unique_names = {str(v.name) for v in value} 997 if len(unique_names) != len(value): 998 raise ValueError("Duplicate tensor descriptor names") 999 1000 return value 1001 1002 @model_validator(mode="after") 1003 def unique_io_names(self) -> Self: 1004 unique_names = {str(ss.name) for s in (self.inputs, self.outputs) for ss in s} 1005 if len(unique_names) != (len(self.inputs) + len(self.outputs)): 1006 raise ValueError("Duplicate tensor descriptor names across inputs/outputs") 1007 1008 return self 1009 1010 @model_validator(mode="after") 1011 def minimum_shape2valid_output(self) -> Self: 1012 tensors_by_name: Dict[ 1013 TensorName, Union[InputTensorDescr, OutputTensorDescr] 1014 ] = {t.name: t for t in self.inputs + self.outputs} 1015 1016 for out in self.outputs: 1017 if isinstance(out.shape, ImplicitOutputShape): 1018 ndim_ref = len(tensors_by_name[out.shape.reference_tensor].shape) 1019 ndim_out_ref = len( 1020 [scale for scale in out.shape.scale if scale is not None] 1021 ) 1022 if ndim_ref != ndim_out_ref: 1023 expanded_dim_note = ( 1024 " Note that expanded dimensions (`scale`: null) are not" 1025 + f" counted for {out.name}'sdimensionality here." 1026 if None in out.shape.scale 1027 else "" 1028 ) 1029 raise ValueError( 1030 f"Referenced tensor '{out.shape.reference_tensor}' with" 1031 + f" {ndim_ref} dimensions does not match output tensor" 1032 + f" '{out.name}' with" 1033 + f" {ndim_out_ref} dimensions.{expanded_dim_note}" 1034 ) 1035 1036 min_out_shape = self._get_min_shape(out, tensors_by_name) 1037 if out.halo: 1038 halo = out.halo 1039 halo_msg = f" for halo {out.halo}" 1040 else: 1041 halo = [0] * len(min_out_shape) 1042 halo_msg = "" 1043 1044 if any([s - 2 * h < 1 for s, h in zip(min_out_shape, halo)]): 1045 raise ValueError( 1046 f"Minimal shape {min_out_shape} of output {out.name} is too" 1047 + f" small{halo_msg}." 1048 ) 1049 1050 return self 1051 1052 @classmethod 1053 def _get_min_shape( 1054 cls, 1055 t: Union[InputTensorDescr, OutputTensorDescr], 1056 tensors_by_name: Dict[TensorName, Union[InputTensorDescr, OutputTensorDescr]], 1057 ) -> Sequence[int]: 1058 """output with subtracted halo has to result in meaningful output even for the minimal input 1059 see https://github.com/bioimage-io/spec-bioimage-io/issues/392 1060 """ 1061 if isinstance(t.shape, collections.abc.Sequence): 1062 return t.shape 1063 elif isinstance(t.shape, ParameterizedInputShape): 1064 return t.shape.min 1065 elif isinstance(t.shape, ImplicitOutputShape): 1066 pass 1067 else: 1068 assert_never(t.shape) 1069 1070 ref_shape = cls._get_min_shape( 1071 tensors_by_name[t.shape.reference_tensor], tensors_by_name 1072 ) 1073 1074 if None not in t.shape.scale: 1075 scale: Sequence[float, ...] 
= t.shape.scale # type: ignore 1076 else: 1077 expanded_dims = [idx for idx, sc in enumerate(t.shape.scale) if sc is None] 1078 new_ref_shape: List[int] = [] 1079 for idx in range(len(t.shape.scale)): 1080 ref_idx = idx - sum(int(exp < idx) for exp in expanded_dims) 1081 new_ref_shape.append(1 if idx in expanded_dims else ref_shape[ref_idx]) 1082 1083 ref_shape = new_ref_shape 1084 assert len(ref_shape) == len(t.shape.scale) 1085 scale = [0.0 if sc is None else sc for sc in t.shape.scale] 1086 1087 offset = t.shape.offset 1088 assert len(offset) == len(scale) 1089 return [int(rs * s + 2 * off) for rs, s, off in zip(ref_shape, scale, offset)] 1090 1091 @model_validator(mode="after") 1092 def validate_tensor_references_in_inputs(self) -> Self: 1093 for t in self.inputs: 1094 for proc in t.preprocessing: 1095 if "reference_tensor" not in proc.kwargs: 1096 continue 1097 1098 ref_tensor = proc.kwargs["reference_tensor"] 1099 if ref_tensor is not None and str(ref_tensor) not in { 1100 str(t.name) for t in self.inputs 1101 }: 1102 raise ValueError(f"'{ref_tensor}' not found in inputs") 1103 1104 if ref_tensor == t.name: 1105 raise ValueError( 1106 f"invalid self reference for preprocessing of tensor {t.name}" 1107 ) 1108 1109 return self 1110 1111 @model_validator(mode="after") 1112 def validate_tensor_references_in_outputs(self) -> Self: 1113 for t in self.outputs: 1114 for proc in t.postprocessing: 1115 if "reference_tensor" not in proc.kwargs: 1116 continue 1117 ref_tensor = proc.kwargs["reference_tensor"] 1118 if ref_tensor is not None and str(ref_tensor) not in { 1119 str(t.name) for t in self.inputs 1120 }: 1121 raise ValueError(f"{ref_tensor} not found in inputs") 1122 1123 return self 1124 1125 packaged_by: List[Author] = Field(default_factory=list) 1126 """The persons that have packaged and uploaded this model. 1127 Only required if those persons differ from the `authors`.""" 1128 1129 parent: Optional[LinkedModel] = None 1130 """The model from which this model is derived, e.g. by fine-tuning the weights.""" 1131 1132 @field_validator("parent", mode="before") 1133 @classmethod 1134 def ignore_url_parent(cls, parent: Any): 1135 if isinstance(parent, dict): 1136 return None 1137 1138 else: 1139 return parent 1140 1141 run_mode: Optional[RunMode] = None 1142 """Custom run mode for this model: for more complex prediction procedures like test time 1143 data augmentation that currently cannot be expressed in the specification. 1144 No standard run modes are defined yet.""" 1145 1146 sample_inputs: List[ImportantFileSource] = Field(default_factory=list) 1147 """∈📦 URLs/relative paths to sample inputs to illustrate possible inputs for the model, 1148 for example stored as PNG or TIFF images. 1149 The sample files primarily serve to inform a human user about an example use case""" 1150 1151 sample_outputs: List[ImportantFileSource] = Field(default_factory=list) 1152 """∈📦 URLs/relative paths to sample outputs corresponding to the `sample_inputs`.""" 1153 1154 test_inputs: NotEmpty[ 1155 List[Annotated[ImportantFileSource, WithSuffix(".npy", case_sensitive=True)]] 1156 ] 1157 """∈📦 Test input tensors compatible with the `inputs` description for a **single test case**. 1158 This means if your model has more than one input, you should provide one URL/relative path for each input. 1159 Each test input should be a file with an ndarray in 1160 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format). 
1161 The extension must be '.npy'.""" 1162 1163 test_outputs: NotEmpty[ 1164 List[Annotated[ImportantFileSource, WithSuffix(".npy", case_sensitive=True)]] 1165 ] 1166 """∈📦 Analog to `test_inputs`.""" 1167 1168 timestamp: Datetime 1169 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format 1170 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).""" 1171 1172 training_data: Union[LinkedDataset, DatasetDescr, None] = None 1173 """The dataset used to train this model""" 1174 1175 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)] 1176 """The weights for this model. 1177 Weights can be given for different formats, but should otherwise be equivalent. 1178 The available weight formats determine which consumers can use this model.""" 1179 1180 @model_validator(mode="before") 1181 @classmethod 1182 def _convert_from_older_format( 1183 cls, data: BioimageioYamlContent, / 1184 ) -> BioimageioYamlContent: 1185 convert_from_older_format(data) 1186 return data 1187 1188 def get_input_test_arrays(self) -> List[NDArray[Any]]: 1189 data = [load_array(download(ipt).path) for ipt in self.test_inputs] 1190 assert all(isinstance(d, np.ndarray) for d in data) 1191 return data 1192 1193 def get_output_test_arrays(self) -> List[NDArray[Any]]: 1194 data = [load_array(download(out).path) for out in self.test_outputs] 1195 assert all(isinstance(d, np.ndarray) for d in data) 1196 return data
Specification of the fields used in a bioimage.io-compliant RDF that describes AI models with pretrained weights.
These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
Version of the bioimage.io model description specification used.
When creating a new model always use the latest micro/patch version described here.
The `format_version` is important for any consumer software to understand how to parse the fields.
bioimage.io-wide unique resource identifier assigned by bioimage.io; version unspecific.
∈📦 URL or relative path to a markdown file with additional documentation. The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. The documentation should include a '[#[#]]# Validation' (sub)section with details on how to quantitatively validate the model on unseen data.
Describes the input tensors expected by this model.
An [SPDX license identifier](https://spdx.org/licenses/). We do not support custom licenses beyond the SPDX license list; if you need that, please [open a GitHub issue](https://github.com/bioimage-io/spec-bioimage-io/issues/new/choose) to discuss your intentions with the community.
A human-readable name of this model. It should be no longer than 64 characters and only contain letters, numbers, underscores, minus signs, or spaces.
991 @field_validator("inputs", "outputs") 992 @classmethod 993 def unique_tensor_descr_names( 994 cls, value: Sequence[Union[InputTensorDescr, OutputTensorDescr]] 995 ) -> Sequence[Union[InputTensorDescr, OutputTensorDescr]]: 996 unique_names = {str(v.name) for v in value} 997 if len(unique_names) != len(value): 998 raise ValueError("Duplicate tensor descriptor names") 999 1000 return value
1002 @model_validator(mode="after") 1003 def unique_io_names(self) -> Self: 1004 unique_names = {str(ss.name) for s in (self.inputs, self.outputs) for ss in s} 1005 if len(unique_names) != (len(self.inputs) + len(self.outputs)): 1006 raise ValueError("Duplicate tensor descriptor names across inputs/outputs") 1007 1008 return self
1010 @model_validator(mode="after") 1011 def minimum_shape2valid_output(self) -> Self: 1012 tensors_by_name: Dict[ 1013 TensorName, Union[InputTensorDescr, OutputTensorDescr] 1014 ] = {t.name: t for t in self.inputs + self.outputs} 1015 1016 for out in self.outputs: 1017 if isinstance(out.shape, ImplicitOutputShape): 1018 ndim_ref = len(tensors_by_name[out.shape.reference_tensor].shape) 1019 ndim_out_ref = len( 1020 [scale for scale in out.shape.scale if scale is not None] 1021 ) 1022 if ndim_ref != ndim_out_ref: 1023 expanded_dim_note = ( 1024 " Note that expanded dimensions (`scale`: null) are not" 1025 + f" counted for {out.name}'sdimensionality here." 1026 if None in out.shape.scale 1027 else "" 1028 ) 1029 raise ValueError( 1030 f"Referenced tensor '{out.shape.reference_tensor}' with" 1031 + f" {ndim_ref} dimensions does not match output tensor" 1032 + f" '{out.name}' with" 1033 + f" {ndim_out_ref} dimensions.{expanded_dim_note}" 1034 ) 1035 1036 min_out_shape = self._get_min_shape(out, tensors_by_name) 1037 if out.halo: 1038 halo = out.halo 1039 halo_msg = f" for halo {out.halo}" 1040 else: 1041 halo = [0] * len(min_out_shape) 1042 halo_msg = "" 1043 1044 if any([s - 2 * h < 1 for s, h in zip(min_out_shape, halo)]): 1045 raise ValueError( 1046 f"Minimal shape {min_out_shape} of output {out.name} is too" 1047 + f" small{halo_msg}." 1048 ) 1049 1050 return self
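To make the arithmetic of this check concrete, a small self-contained sketch (all values are hypothetical): the minimal input shape is propagated through the implicit output shape and the halo is subtracted from both sides; every resulting dimension must stay at least 1.

min_input = [1, 64, 64, 1]    # ParameterizedInputShape.min of the reference tensor
scale = [1.0, 0.5, 0.5, 1.0]  # ImplicitOutputShape.scale
offset = [0, 0, 0, 0]         # ImplicitOutputShape.offset
halo = [0, 8, 8, 0]           # OutputTensorDescr.halo

min_output = [int(i * sc + 2 * off) for i, sc, off in zip(min_input, scale, offset)]
print(min_output)  # [1, 32, 32, 1]
assert all(s - 2 * h >= 1 for s, h in zip(min_output, halo)), "minimal output too small for halo"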
1091 @model_validator(mode="after") 1092 def validate_tensor_references_in_inputs(self) -> Self: 1093 for t in self.inputs: 1094 for proc in t.preprocessing: 1095 if "reference_tensor" not in proc.kwargs: 1096 continue 1097 1098 ref_tensor = proc.kwargs["reference_tensor"] 1099 if ref_tensor is not None and str(ref_tensor) not in { 1100 str(t.name) for t in self.inputs 1101 }: 1102 raise ValueError(f"'{ref_tensor}' not found in inputs") 1103 1104 if ref_tensor == t.name: 1105 raise ValueError( 1106 f"invalid self reference for preprocessing of tensor {t.name}" 1107 ) 1108 1109 return self
1111 @model_validator(mode="after") 1112 def validate_tensor_references_in_outputs(self) -> Self: 1113 for t in self.outputs: 1114 for proc in t.postprocessing: 1115 if "reference_tensor" not in proc.kwargs: 1116 continue 1117 ref_tensor = proc.kwargs["reference_tensor"] 1118 if ref_tensor is not None and str(ref_tensor) not in { 1119 str(t.name) for t in self.inputs 1120 }: 1121 raise ValueError(f"{ref_tensor} not found in inputs") 1122 1123 return self
The persons that have packaged and uploaded this model.
Only required if those persons differ from the `authors`.
The model from which this model is derived, e.g. by fine-tuning the weights.
Custom run mode for this model: for more complex prediction procedures like test time data augmentation that currently cannot be expressed in the specification. No standard run modes are defined yet.
∈📦 URLs/relative paths to sample inputs to illustrate possible inputs for the model, for example stored as PNG or TIFF images. The sample files primarily serve to inform a human user about an example use case
∈📦 URLs/relative paths to sample outputs corresponding to the `sample_inputs`.
∈📦 Test input tensors compatible with the `inputs` description for a **single test case**. This means if your model has more than one input, you should provide one URL/relative path for each input. Each test input should be a file with an ndarray in [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format). The extension must be '.npy'.
∈📦 Analogous to `test_inputs`.
The dataset used to train this model
The weights for this model. Weights can be given for different formats, but should otherwise be equivalent. The available weight formats determine which consumers can use this model.
124 def wrapped_model_post_init(self: BaseModel, context: Any, /) -> None: 125 """We need to both initialize private attributes and call the user-defined model_post_init 126 method. 127 """ 128 init_private_attributes(self, context) 129 original_model_post_init(self, context)
We need to both initialize private attributes and call the user-defined model_post_init method.