Coverage for bioimageio/spec/model/v0_4.py: 90%
595 statements

from __future__ import annotations

import collections.abc
from typing import (
    TYPE_CHECKING,
    Any,
    ClassVar,
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
)

import numpy as np
from annotated_types import Ge, Interval, MaxLen, MinLen, MultipleOf
from numpy.typing import NDArray
from pydantic import (
    AllowInfNan,
    Discriminator,
    Field,
    RootModel,
    SerializationInfo,
    SerializerFunctionWrapHandler,
    StringConstraints,
    TypeAdapter,
    ValidationInfo,
    WrapSerializer,
    field_validator,
    model_validator,
)
from typing_extensions import Annotated, Self, assert_never, get_args

from .._internal.common_nodes import (
    KwargsNode,
    Node,
    NodeWithExplicitlySetFields,
)
from .._internal.constants import SHA256_HINT
from .._internal.field_validation import validate_unique_entries
from .._internal.field_warning import issue_warning, warn
from .._internal.io import (
    BioimageioYamlContent,
    WithSuffix,
    download,
    include_in_package_serializer,
)
from .._internal.io import FileDescr as FileDescr
from .._internal.io_basics import AbsoluteFilePath as AbsoluteFilePath
from .._internal.io_basics import Sha256 as Sha256
from .._internal.io_utils import load_array
from .._internal.packaging_context import packaging_context_var
from .._internal.types import Datetime as Datetime
from .._internal.types import Identifier as Identifier
from .._internal.types import ImportantFileSource, LowerCaseIdentifier
from .._internal.types import LicenseId as LicenseId
from .._internal.types import NotEmpty as NotEmpty
from .._internal.url import HttpUrl as HttpUrl
from .._internal.validated_string_with_inner_node import ValidatedStringWithInnerNode
from .._internal.validator_annotations import AfterValidator, RestrictCharacters
from .._internal.version_type import Version as Version
from .._internal.warning_levels import ALERT, INFO
from ..dataset.v0_2 import VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS
from ..dataset.v0_2 import DatasetDescr as DatasetDescr
from ..dataset.v0_2 import LinkedDataset as LinkedDataset
from ..generic.v0_2 import AttachmentsDescr as AttachmentsDescr
from ..generic.v0_2 import Author as Author
from ..generic.v0_2 import BadgeDescr as BadgeDescr
from ..generic.v0_2 import CiteEntry as CiteEntry
from ..generic.v0_2 import Doi as Doi
from ..generic.v0_2 import GenericModelDescrBase
from ..generic.v0_2 import LinkedResource as LinkedResource
from ..generic.v0_2 import Maintainer as Maintainer
from ..generic.v0_2 import OrcidId as OrcidId
from ..generic.v0_2 import RelativeFilePath as RelativeFilePath
from ..generic.v0_2 import ResourceId as ResourceId
from ..generic.v0_2 import Uploader as Uploader
from ._v0_4_converter import convert_from_older_format


class ModelId(ResourceId):
    pass


AxesStr = Annotated[
    str, RestrictCharacters("bitczyx"), AfterValidator(validate_unique_entries)
]
AxesInCZYX = Annotated[
    str, RestrictCharacters("czyx"), AfterValidator(validate_unique_entries)
]

PostprocessingName = Literal[
    "binarize",
    "clip",
    "scale_linear",
    "sigmoid",
    "zero_mean_unit_variance",
    "scale_range",
    "scale_mean_variance",
]
PreprocessingName = Literal[
    "binarize",
    "clip",
    "scale_linear",
    "sigmoid",
    "zero_mean_unit_variance",
    "scale_range",
]


class TensorName(LowerCaseIdentifier):
    pass


class CallableFromDepencencyNode(Node):
    _submodule_adapter: ClassVar[TypeAdapter[Identifier]] = TypeAdapter(Identifier)

    module_name: str
    """The Python module that implements **callable_name**."""

    @field_validator("module_name", mode="after")
    def _check_submodules(cls, module_name: str) -> str:
        for submod in module_name.split("."):
            _ = cls._submodule_adapter.validate_python(submod)

        return module_name

    callable_name: Identifier
    """The callable Python identifier implemented in module **module_name**."""


class CallableFromDepencency(ValidatedStringWithInnerNode[CallableFromDepencencyNode]):
    _inner_node_class = CallableFromDepencencyNode
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            str,
            StringConstraints(strip_whitespace=True, pattern=r"^.+\..+$"),
        ]
    ]

    @classmethod
    def _get_data(cls, valid_string_data: str):
        *mods, callname = valid_string_data.split(".")
        return dict(module_name=".".join(mods), callable_name=callname)

    @property
    def module_name(self):
        """The Python module that implements **callable_name**."""
        return self._inner_node.module_name

    @property
    def callable_name(self):
        """The callable Python identifier implemented in module **module_name**."""
        return self._inner_node.callable_name
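

# Illustrative usage sketch (not part of the original module), assuming that
# `ValidatedStringWithInnerNode` subclasses validate and parse on construction;
# the dotted path below is hypothetical:
#
#     c = CallableFromDepencency("my_module.submodule.get_my_model")
#     assert c.module_name == "my_module.submodule"
#     assert c.callable_name == "get_my_model"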


class CallableFromFileNode(Node):
    source_file: Annotated[
        Union[RelativeFilePath, HttpUrl],
        Field(union_mode="left_to_right"),
        include_in_package_serializer,
    ]
    """The Python source file that implements **callable_name**."""
    callable_name: Identifier
    """The callable Python identifier implemented in **source_file**."""


class CallableFromFile(ValidatedStringWithInnerNode[CallableFromFileNode]):
    _inner_node_class = CallableFromFileNode
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            str,
            StringConstraints(strip_whitespace=True, pattern=r"^.+:.+$"),
        ]
    ]

    @classmethod
    def _get_data(cls, valid_string_data: str):
        *file_parts, callname = valid_string_data.split(":")
        return dict(source_file=":".join(file_parts), callable_name=callname)

    @property
    def source_file(self):
        """The Python source file that implements **callable_name**."""
        return self._inner_node.source_file

    @property
    def callable_name(self):
        """The callable Python identifier implemented in **source_file**."""
        return self._inner_node.callable_name
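

# Illustrative sketch (not part of the original module): `_get_data` splits on
# every ":" and rejoins all but the last part, so a source that itself contains
# ":" (e.g. a URL) stays intact; the URL below is hypothetical:
#
#     c = CallableFromFile("https://example.com/my_function.py:MyNetworkClass")
#     assert str(c.source_file) == "https://example.com/my_function.py"
#     assert c.callable_name == "MyNetworkClass"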


CustomCallable = Annotated[
    Union[CallableFromFile, CallableFromDepencency], Field(union_mode="left_to_right")
]


class DependenciesNode(Node):
    manager: Annotated[NotEmpty[str], Field(examples=["conda", "maven", "pip"])]
    """Dependency manager"""

    file: ImportantFileSource
    """Dependency file"""


class Dependencies(ValidatedStringWithInnerNode[DependenciesNode]):
    _inner_node_class = DependenciesNode
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            str,
            StringConstraints(strip_whitespace=True, pattern=r"^.+:.+$"),
        ]
    ]

    @classmethod
    def _get_data(cls, valid_string_data: str):
        manager, *file_parts = valid_string_data.split(":")
        return dict(manager=manager, file=":".join(file_parts))

    @property
    def manager(self):
        """Dependency manager"""
        return self._inner_node.manager

    @property
    def file(self):
        """Dependency file"""
        return self._inner_node.file
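

# Illustrative sketch (not part of the original module), assuming construction
# validates the "<manager>:<file>" pattern and parses it via `_get_data`:
#
#     deps = Dependencies("conda:environment.yaml")
#     assert deps.manager == "conda"
#     assert str(deps.file) == "environment.yaml"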


WeightsFormat = Literal[
    "keras_hdf5",
    "onnx",
    "pytorch_state_dict",
    "tensorflow_js",
    "tensorflow_saved_model_bundle",
    "torchscript",
]


class WeightsDescr(Node):
    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
    onnx: Optional[OnnxWeightsDescr] = None
    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
        None
    )
    torchscript: Optional[TorchscriptWeightsDescr] = None

    @model_validator(mode="after")
    def check_one_entry(self) -> Self:
        if all(
            entry is None
            for entry in [
                self.keras_hdf5,
                self.onnx,
                self.pytorch_state_dict,
                self.tensorflow_js,
                self.tensorflow_saved_model_bundle,
                self.torchscript,
            ]
        ):
            raise ValueError("Missing weights entry")

        return self

    def __getitem__(
        self,
        key: WeightsFormat,
    ):
        if key == "keras_hdf5":
            ret = self.keras_hdf5
        elif key == "onnx":
            ret = self.onnx
        elif key == "pytorch_state_dict":
            ret = self.pytorch_state_dict
        elif key == "tensorflow_js":
            ret = self.tensorflow_js
        elif key == "tensorflow_saved_model_bundle":
            ret = self.tensorflow_saved_model_bundle
        elif key == "torchscript":
            ret = self.torchscript
        else:
            raise KeyError(key)

        if ret is None:
            raise KeyError(key)

        return ret

    @property
    def available_formats(self):
        return {
            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
            **({} if self.onnx is None else {"onnx": self.onnx}),
            **(
                {}
                if self.pytorch_state_dict is None
                else {"pytorch_state_dict": self.pytorch_state_dict}
            ),
            **(
                {}
                if self.tensorflow_js is None
                else {"tensorflow_js": self.tensorflow_js}
            ),
            **(
                {}
                if self.tensorflow_saved_model_bundle is None
                else {
                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
                }
            ),
            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
        }

    @property
    def missing_formats(self):
        return {
            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
        }
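

# Illustrative sketch (not part of the original module); the weights file path
# is hypothetical:
#
#     weights = WeightsDescr(
#         onnx=OnnxWeightsDescr(source="weights.onnx", opset_version=15)
#     )
#     _ = weights["onnx"]                  # lookup by weights format name
#     assert set(weights.available_formats) == {"onnx"}
#     assert "torchscript" in weights.missing_formats
#     weights["torchscript"]               # raises KeyError: entry is None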


class WeightsEntryDescrBase(FileDescr):
    type: ClassVar[WeightsFormat]
    weights_format_name: ClassVar[str]  # human readable

    source: ImportantFileSource
    """The weights file."""

    attachments: Annotated[
        Union[AttachmentsDescr, None],
        warn(None, "Weights entry depends on additional attachments.", ALERT),
    ] = None
    """Attachments that are specific to this weights entry."""

    authors: Union[List[Author], None] = None
    """Authors:
    Either the person(s) that have trained this model, resulting in the original weights file
    (if this is the initial weights entry, i.e. it does not have a `parent`),
    or the person(s) who have converted the weights to this weights format
    (if this is a child weights entry, i.e. it has a `parent` field).
    """

    dependencies: Annotated[
        Optional[Dependencies],
        warn(
            None,
            "Custom dependencies ({value}) specified. Avoid this whenever possible "
            + "to allow execution in a wider range of software environments.",
        ),
        Field(
            examples=[
                "conda:environment.yaml",
                "maven:./pom.xml",
                "pip:./requirements.txt",
            ]
        ),
    ] = None
    """Dependency manager and dependency file, specified as `<dependency manager>:<relative file path>`."""

    parent: Annotated[
        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
    ] = None
    """The source weights these weights were converted from.
    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
    the `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
    All weights entries except one (the initial set of weights resulting from training the model)
    need to have this field."""

    @model_validator(mode="after")
    def check_parent_is_not_self(self) -> Self:
        if self.type == self.parent:
            raise ValueError("Weights entry can't be its own parent.")

        return self


class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
    type = "keras_hdf5"
    weights_format_name: ClassVar[str] = "Keras HDF5"
    tensorflow_version: Optional[Version] = None
    """TensorFlow version used to create these weights"""

    @field_validator("tensorflow_version", mode="after")
    @classmethod
    def _tfv(cls, value: Any):
        if value is None:
            issue_warning(
                "missing. Please specify the TensorFlow version"
                + " these weights were created with.",
                value=value,
                severity=ALERT,
                field="tensorflow_version",
            )
        return value


class OnnxWeightsDescr(WeightsEntryDescrBase):
    type = "onnx"
    weights_format_name: ClassVar[str] = "ONNX"
    opset_version: Optional[Annotated[int, Ge(7)]] = None
    """ONNX opset version"""

    @field_validator("opset_version", mode="after")
    @classmethod
    def _ov(cls, value: Any):
        if value is None:
            issue_warning(
                "Missing ONNX opset version (aka ONNX opset number). "
                + "Please specify the ONNX opset version these weights were created"
                + " with.",
                value=value,
                severity=ALERT,
                field="opset_version",
            )
        return value


class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
    type = "pytorch_state_dict"
    weights_format_name: ClassVar[str] = "Pytorch State Dict"
    architecture: CustomCallable = Field(
        examples=["my_function.py:MyNetworkClass", "my_module.submodule.get_my_model"]
    )
    """Callable returning a torch.nn.Module instance.
    Local implementation: `<relative path to file>:<identifier of implementation within the file>`.
    Implementation in a dependency: `<dependency-package>.<[dependency-module]>.<identifier>`."""

    architecture_sha256: Annotated[
        Optional[Sha256],
        Field(
            description=(
                "The SHA256 of the architecture source file, if the architecture is not"
                " defined in a module listed in `dependencies`\n"
            )
            + SHA256_HINT,
        ),
    ] = None
    """The SHA256 of the architecture source file,
    if the architecture is not defined in a module listed in `dependencies`"""

    @model_validator(mode="after")
    def check_architecture_sha256(self) -> Self:
        if isinstance(self.architecture, CallableFromFile):
            if self.architecture_sha256 is None:
                raise ValueError(
                    "Missing required `architecture_sha256` for `architecture` with"
                    + " source file."
                )
        elif self.architecture_sha256 is not None:
            raise ValueError(
                "Got `architecture_sha256` for architecture that does not have a source"
                + " file."
            )

        return self

    kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments for the `architecture` callable."""

    pytorch_version: Optional[Version] = None
    """Version of the PyTorch library used.
    If `dependencies` is specified it should include pytorch and the version has to match.
    (`dependencies` overrules `pytorch_version`)"""

    @field_validator("pytorch_version", mode="after")
    @classmethod
    def _ptv(cls, value: Any):
        if value is None:
            issue_warning(
                "missing. Please specify the PyTorch version these"
                + " PyTorch state dict weights were created with.",
                value=value,
                severity=ALERT,
                field="pytorch_version",
            )
        return value


class TorchscriptWeightsDescr(WeightsEntryDescrBase):
    type = "torchscript"
    weights_format_name: ClassVar[str] = "TorchScript"
    pytorch_version: Optional[Version] = None
    """Version of the PyTorch library used."""

    @field_validator("pytorch_version", mode="after")
    @classmethod
    def _ptv(cls, value: Any):
        if value is None:
            issue_warning(
                "missing. Please specify the PyTorch version these"
                + " Torchscript weights were created with.",
                value=value,
                severity=ALERT,
                field="pytorch_version",
            )
        return value


class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
    type = "tensorflow_js"
    weights_format_name: ClassVar[str] = "Tensorflow.js"
    tensorflow_version: Optional[Version] = None
    """Version of the TensorFlow library used."""

    @field_validator("tensorflow_version", mode="after")
    @classmethod
    def _tfv(cls, value: Any):
        if value is None:
            issue_warning(
                "missing. Please specify the TensorFlow version"
                + " these TensorflowJs weights were created with.",
                value=value,
                severity=ALERT,
                field="tensorflow_version",
            )
        return value

    source: ImportantFileSource
    """The multi-file weights.
    All required files/folders should be packaged in a zip archive."""


class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
    type = "tensorflow_saved_model_bundle"
    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
    tensorflow_version: Optional[Version] = None
    """Version of the TensorFlow library used."""

    @field_validator("tensorflow_version", mode="after")
    @classmethod
    def _tfv(cls, value: Any):
        if value is None:
            issue_warning(
                "missing. Please specify the TensorFlow version"
                + " these Tensorflow saved model bundle weights were created with.",
                value=value,
                severity=ALERT,
                field="tensorflow_version",
            )
        return value


class ParameterizedInputShape(Node):
    """A sequence of valid shapes given by `shape_k = min + k * step for k in {0, 1, ...}`."""

    min: NotEmpty[List[int]]
    """The minimum input shape"""

    step: NotEmpty[List[int]]
    """The minimum shape change"""

    def __len__(self) -> int:
        return len(self.min)

    @model_validator(mode="after")
    def matching_lengths(self) -> Self:
        if len(self.min) != len(self.step):
            raise ValueError("`min` and `step` required to have the same length")

        return self
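

# Illustrative sketch (not part of the original module): enumerating the first
# valid shapes `shape_k = min + k * step`:
#
#     pshape = ParameterizedInputShape(min=[1, 64, 64, 1], step=[0, 32, 32, 0])
#     [[m + k * s for m, s in zip(pshape.min, pshape.step)] for k in range(3)]
#     # -> [[1, 64, 64, 1], [1, 96, 96, 1], [1, 128, 128, 1]]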


class ImplicitOutputShape(Node):
    """Output tensor shape depending on an input tensor shape.
    `shape(output_tensor) = shape(input_tensor) * scale + 2 * offset`"""

    reference_tensor: TensorName
    """Name of the reference tensor."""

    scale: NotEmpty[List[Optional[float]]]
    """output_pix/input_pix for each dimension.
    'null' values indicate new dimensions, whose length is defined by 2*`offset`"""

    offset: NotEmpty[List[Union[int, Annotated[float, MultipleOf(0.5)]]]]
    """Position of origin with respect to the input."""

    def __len__(self) -> int:
        return len(self.scale)

    @model_validator(mode="after")
    def matching_lengths(self) -> Self:
        if len(self.scale) != len(self.offset):
            raise ValueError(
                f"scale {self.scale} has to have same length as offset {self.offset}!"
            )
        # if we have an expanded dimension, make sure that its offset is not zero
        for sc, off in zip(self.scale, self.offset):
            if sc is None and not off:
                raise ValueError("`offset` must not be zero if `scale` is none/zero")

        return self
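

# Illustrative sketch (not part of the original module): resolving an implicit
# output shape against a reference input shape; the tensor name is hypothetical:
#
#     out = ImplicitOutputShape(
#         reference_tensor="raw", scale=[1.0, 2.0, 2.0, 1.0], offset=[0, 0, 0, 0]
#     )
#     input_shape = [1, 64, 64, 1]
#     [int(i * s + 2 * o) for i, s, o in zip(input_shape, out.scale, out.offset)]
#     # -> [1, 128, 128, 1]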


class TensorDescrBase(Node):
    name: TensorName
    """Tensor name. No duplicates are allowed."""

    description: str = ""

    axes: AxesStr
    """Axes identifying characters. Same length and order as the axes in `shape`.
    | axis | description |
    | --- | --- |
    | b | batch (groups multiple samples) |
    | i | instance/index/element |
    | t | time |
    | c | channel |
    | z | spatial dimension z |
    | y | spatial dimension y |
    | x | spatial dimension x |
    """

    data_range: Optional[
        Tuple[Annotated[float, AllowInfNan(True)], Annotated[float, AllowInfNan(True)]]
    ] = None
    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
    If not specified, the full data range that can be expressed in `data_type` is allowed."""


class ProcessingKwargs(KwargsNode):
    """base class for pre-/postprocessing keyword arguments"""


class ProcessingDescrBase(NodeWithExplicitlySetFields):
    """processing base class"""


class BinarizeKwargs(ProcessingKwargs):
    """keyword arguments for `BinarizeDescr`"""

    threshold: float
    """The fixed threshold"""


class BinarizeDescr(ProcessingDescrBase):
    """Binarize the tensor with a fixed `BinarizeKwargs.threshold`.
    Values above the threshold will be set to one, values below the threshold to zero.
    """

    implemented_name: ClassVar[Literal["binarize"]] = "binarize"
    if TYPE_CHECKING:
        name: Literal["binarize"] = "binarize"
    else:
        name: Literal["binarize"]

    kwargs: BinarizeKwargs


class ClipKwargs(ProcessingKwargs):
    """keyword arguments for `ClipDescr`"""

    min: float
    """minimum value for clipping"""
    max: float
    """maximum value for clipping"""


class ClipDescr(ProcessingDescrBase):
    """Clip tensor values to a range.

    Set tensor values below `ClipKwargs.min` to `ClipKwargs.min`
    and above `ClipKwargs.max` to `ClipKwargs.max`.
    """

    implemented_name: ClassVar[Literal["clip"]] = "clip"
    if TYPE_CHECKING:
        name: Literal["clip"] = "clip"
    else:
        name: Literal["clip"]

    kwargs: ClipKwargs


class ScaleLinearKwargs(ProcessingKwargs):
    """keyword arguments for `ScaleLinearDescr`"""

    axes: Annotated[Optional[AxesInCZYX], Field(examples=["xy"])] = None
    """The subset of axes to scale jointly.
    For example xy to scale the two image axes for 2d data jointly."""

    gain: Union[float, List[float]] = 1.0
    """multiplicative factor"""

    offset: Union[float, List[float]] = 0.0
    """additive term"""

    @model_validator(mode="after")
    def either_gain_or_offset(self) -> Self:
        if (
            self.gain == 1.0
            or (isinstance(self.gain, list) and all(g == 1.0 for g in self.gain))
        ) and (
            self.offset == 0.0
            or (isinstance(self.offset, list) and all(off == 0.0 for off in self.offset))
        ):
            raise ValueError(
                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or"
                + " `offset` != 0.0."
            )

        return self
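

# Illustrative sketch (not part of the original module): per-channel gains for
# jointly scaled image axes; all-default values are rejected as redundant:
#
#     _ = ScaleLinearKwargs(axes="xy", gain=[2.0, 0.5, 1.0])
#     ScaleLinearKwargs()  # raises ValueError: redundant linear scaling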


class ScaleLinearDescr(ProcessingDescrBase):
    """Fixed linear scaling."""

    implemented_name: ClassVar[Literal["scale_linear"]] = "scale_linear"
    if TYPE_CHECKING:
        name: Literal["scale_linear"] = "scale_linear"
    else:
        name: Literal["scale_linear"]

    kwargs: ScaleLinearKwargs


class SigmoidDescr(ProcessingDescrBase):
    """The logistic sigmoid function, a.k.a. expit function."""

    implemented_name: ClassVar[Literal["sigmoid"]] = "sigmoid"
    if TYPE_CHECKING:
        name: Literal["sigmoid"] = "sigmoid"
    else:
        name: Literal["sigmoid"]

    @property
    def kwargs(self) -> ProcessingKwargs:
        """empty kwargs"""
        return ProcessingKwargs()


class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
    """keyword arguments for `ZeroMeanUnitVarianceDescr`"""

    mode: Literal["fixed", "per_dataset", "per_sample"] = "fixed"
    """Mode for computing mean and variance.
    | mode | description |
    | ----------- | ------------------------------------ |
    | fixed | Fixed values for mean and variance |
    | per_dataset | Compute for the entire dataset |
    | per_sample | Compute for each sample individually |
    """
    axes: Annotated[AxesInCZYX, Field(examples=["xy"])]
    """The subset of axes to normalize jointly.
    For example `xy` to normalize the two image axes for 2d data jointly."""

    mean: Annotated[
        Union[float, NotEmpty[List[float]], None], Field(examples=[(1.1, 2.2, 3.3)])
    ] = None
    """The mean value(s) to use for `mode: fixed`.
    For example `[1.1, 2.2, 3.3]` in the case of a 3 channel image with `axes: xy`."""
    # todo: check if means match input axes (for mode 'fixed')

    std: Annotated[
        Union[float, NotEmpty[List[float]], None], Field(examples=[(0.1, 0.2, 0.3)])
    ] = None
    """The standard deviation values to use for `mode: fixed`. Analogous to mean."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""

    @model_validator(mode="after")
    def mean_and_std_match_mode(self) -> Self:
        if self.mode == "fixed" and (self.mean is None or self.std is None):
            raise ValueError("`mean` and `std` are required for `mode: fixed`.")
        elif self.mode != "fixed" and (self.mean is not None or self.std is not None):
            raise ValueError(f"`mean` and `std` not allowed for `mode: {self.mode}`")

        return self
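

# Illustrative sketch (not part of the original module): `mode: fixed` requires
# explicit statistics, while the other modes forbid them:
#
#     _ = ZeroMeanUnitVarianceKwargs(
#         mode="fixed", axes="xy", mean=[1.1, 2.2, 3.3], std=[0.1, 0.2, 0.3]
#     )
#     ZeroMeanUnitVarianceKwargs(mode="per_sample", axes="xy", mean=1.0)
#     # raises ValueError: `mean` and `std` not allowed for `mode: per_sample`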


class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
    """Subtract mean and divide by variance."""

    implemented_name: ClassVar[Literal["zero_mean_unit_variance"]] = (
        "zero_mean_unit_variance"
    )
    if TYPE_CHECKING:
        name: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
    else:
        name: Literal["zero_mean_unit_variance"]

    kwargs: ZeroMeanUnitVarianceKwargs


class ScaleRangeKwargs(ProcessingKwargs):
    """keyword arguments for `ScaleRangeDescr`

    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
    this processing step normalizes data to the [0, 1] interval.
    For other percentiles the normalized values will partially be outside the [0, 1]
    interval. Use `ScaleRangeDescr` followed by `ClipDescr` if you want to limit the
    normalized values to a range.
    """

    mode: Literal["per_dataset", "per_sample"]
    """Mode for computing percentiles.
    | mode | description |
    | ----------- | ------------------------------------ |
    | per_dataset | compute for the entire dataset |
    | per_sample | compute for each sample individually |
    """
    axes: Annotated[AxesInCZYX, Field(examples=["xy"])]
    """The subset of axes to normalize jointly.
    For example xy to normalize the two image axes for 2d data jointly."""

    min_percentile: Annotated[Union[int, float], Interval(ge=0, lt=100)] = 0.0
    """The lower percentile used to determine the value to align with zero."""

    max_percentile: Annotated[Union[int, float], Interval(gt=1, le=100)] = 100.0
    """The upper percentile used to determine the value to align with one.
    Has to be bigger than `min_percentile`.
    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
    accepting percentiles specified in the range 0.0 to 1.0."""

    @model_validator(mode="after")
    def min_smaller_max(self, info: ValidationInfo) -> Self:
        if self.min_percentile >= self.max_percentile:
            raise ValueError(
                f"min_percentile {self.min_percentile} >= max_percentile"
                + f" {self.max_percentile}"
            )

        return self

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability.
    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
    with `v_lower,v_upper` values at the respective percentiles."""

    reference_tensor: Optional[TensorName] = None
    """Tensor name to compute the percentiles from. Default: The tensor itself.
    For any tensor in `inputs` only input tensor references are allowed.
    For a tensor in `outputs` only input tensor references are allowed if `mode: per_dataset`."""
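

# Illustrative sketch (not part of the original module): with the percentile
# values v_lower/v_upper taken at min_percentile/max_percentile, this step
# computes `out = (tensor - v_lower) / (v_upper - v_lower + eps)`:
#
#     _ = ScaleRangeKwargs(
#         mode="per_sample", axes="xy", min_percentile=1.0, max_percentile=99.0
#     )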


class ScaleRangeDescr(ProcessingDescrBase):
    """Scale with percentiles."""

    implemented_name: ClassVar[Literal["scale_range"]] = "scale_range"
    if TYPE_CHECKING:
        name: Literal["scale_range"] = "scale_range"
    else:
        name: Literal["scale_range"]

    kwargs: ScaleRangeKwargs


class ScaleMeanVarianceKwargs(ProcessingKwargs):
    """keyword arguments for `ScaleMeanVarianceDescr`"""

    mode: Literal["per_dataset", "per_sample"]
    """Mode for computing mean and variance.
    | mode | description |
    | ----------- | ------------------------------------ |
    | per_dataset | Compute for the entire dataset |
    | per_sample | Compute for each sample individually |
    """

    reference_tensor: TensorName
    """Name of tensor to match."""

    axes: Annotated[Optional[AxesInCZYX], Field(examples=["xy"])] = None
    """The subset of axes to scale jointly.
    For example xy to normalize the two image axes for 2d data jointly.
    Default: scale all non-batch axes jointly."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability:
    `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean`."""


class ScaleMeanVarianceDescr(ProcessingDescrBase):
    """Scale the tensor s.t. its mean and variance match a reference tensor."""

    implemented_name: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
    if TYPE_CHECKING:
        name: Literal["scale_mean_variance"] = "scale_mean_variance"
    else:
        name: Literal["scale_mean_variance"]

    kwargs: ScaleMeanVarianceKwargs


PreprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        ScaleLinearDescr,
        SigmoidDescr,
        ZeroMeanUnitVarianceDescr,
        ScaleRangeDescr,
    ],
    Discriminator("name"),
]
PostprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        ScaleLinearDescr,
        SigmoidDescr,
        ZeroMeanUnitVarianceDescr,
        ScaleRangeDescr,
        ScaleMeanVarianceDescr,
    ],
    Discriminator("name"),
]


class InputTensorDescr(TensorDescrBase):
    data_type: Literal["float32", "uint8", "uint16"]
    """For now an input tensor is expected to be given as `float32`.
    The data flow in bioimage.io models is explained
    [in this diagram](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit)."""

    shape: Annotated[
        Union[Sequence[int], ParameterizedInputShape],
        Field(
            examples=[(1, 512, 512, 1), dict(min=(1, 64, 64, 1), step=(0, 32, 32, 0))]
        ),
    ]
    """Specification of input tensor shape."""

    preprocessing: List[PreprocessingDescr] = Field(default_factory=list)
    """Description of how this input should be preprocessed."""

    @model_validator(mode="after")
    def zero_batch_step_and_one_batch_size(self) -> Self:
        bidx = self.axes.find("b")
        if bidx == -1:
            return self

        if isinstance(self.shape, ParameterizedInputShape):
            step = self.shape.step
            shape = self.shape.min
            if step[bidx] != 0:
                raise ValueError(
                    "Input shape step has to be zero in the batch dimension (the batch"
                    + " dimension can always be increased, but `step` should specify"
                    + " how to increase the minimal shape to find the largest single"
                    + " batch shape)"
                )
        else:
            shape = self.shape

        if shape[bidx] != 1:
            raise ValueError("Input shape has to be 1 in the batch dimension b.")

        return self

    @model_validator(mode="after")
    def validate_preprocessing_kwargs(self) -> Self:
        for p in self.preprocessing:
            kwargs_axes = p.kwargs.get("axes")
            if isinstance(kwargs_axes, str) and any(
                a not in self.axes for a in kwargs_axes
            ):
                raise ValueError("`kwargs.axes` needs to be subset of `axes`")

        return self


class OutputTensorDescr(TensorDescrBase):
    data_type: Literal[
        "float32",
        "float64",
        "uint8",
        "int8",
        "uint16",
        "int16",
        "uint32",
        "int32",
        "uint64",
        "int64",
        "bool",
    ]
    """Data type.
    The data flow in bioimage.io models is explained
    [in this diagram](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit)."""

    shape: Union[Sequence[int], ImplicitOutputShape]
    """Output tensor shape."""

    halo: Optional[Sequence[int]] = None
    """The `halo` that should be cropped from the output tensor to avoid boundary effects.
    The `halo` is to be cropped from both sides, i.e. `shape_after_crop = shape - 2 * halo`.
    To document a `halo` that is already cropped by the model `shape.offset` has to be used instead."""

    postprocessing: List[PostprocessingDescr] = Field(default_factory=list)
    """Description of how this output should be postprocessed."""

    @model_validator(mode="after")
    def matching_halo_length(self) -> Self:
        if self.halo and len(self.halo) != len(self.shape):
            raise ValueError(
                f"halo {self.halo} has to have same length as shape {self.shape}!"
            )

        return self

    @model_validator(mode="after")
    def validate_postprocessing_kwargs(self) -> Self:
        for p in self.postprocessing:
            kwargs_axes = p.kwargs.get("axes", "")
            if not isinstance(kwargs_axes, str):
                raise ValueError(f"Expected {kwargs_axes} to be a string")

            if any(a not in self.axes for a in kwargs_axes):
                raise ValueError("`kwargs.axes` needs to be subset of axes")

        return self
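

# Illustrative sketch (not part of the original module): a halo is cropped from
# both sides of each axis:
#
#     shape, halo = [1, 256, 256, 1], [0, 16, 16, 0]
#     [s - 2 * h for s, h in zip(shape, halo)]  # -> [1, 224, 224, 1]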


KnownRunMode = Literal["deepimagej"]


class RunMode(Node):
    name: Annotated[
        Union[KnownRunMode, str], warn(KnownRunMode, "Unknown run mode '{value}'.")
    ]
    """Run mode name"""

    kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Run mode specific keyword arguments"""


class LinkedModel(Node):
    """Reference to a bioimage.io model."""

    id: Annotated[ModelId, Field(examples=["affable-shark", "ambitious-sloth"])]
    """A valid model `id` from the bioimage.io collection."""

    version_number: Optional[int] = None
    """version number (n-th published version, not the semantic version) of linked model"""


def package_weights(
    value: Node,  # Union[v0_4.WeightsDescr, v0_5.WeightsDescr]
    handler: SerializerFunctionWrapHandler,
    info: SerializationInfo,
):
    ctxt = packaging_context_var.get()
    if ctxt is not None and ctxt.weights_priority_order is not None:
        for wf in ctxt.weights_priority_order:
            w = getattr(value, wf, None)
            if w is not None:
                break
        else:
            raise ValueError(
                "None of the weight formats in `weights_priority_order`"
                + f" ({ctxt.weights_priority_order}) is present in the given model."
            )

        assert isinstance(w, Node), type(w)
        # construct WeightsDescr with new single weight format entry
        new_w = w.model_construct(**{k: v for k, v in w if k != "parent"})
        value = value.model_construct(None, **{wf: new_w})

    return handler(
        value, info  # pyright: ignore[reportArgumentType] # taken from pydantic docs
    )


class ModelDescr(GenericModelDescrBase):
    """Specification of the fields used in a bioimage.io-compliant RDF that describes AI models with pretrained weights.

    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
    """

    implemented_format_version: ClassVar[Literal["0.4.10"]] = "0.4.10"
    if TYPE_CHECKING:
        format_version: Literal["0.4.10"] = "0.4.10"
    else:
        format_version: Literal["0.4.10"]
    """Version of the bioimage.io model description specification used.
    When creating a new model always use the latest micro/patch version described here.
    The `format_version` is important for any consumer software to understand how to parse the fields.
    """

    implemented_type: ClassVar[Literal["model"]] = "model"
    if TYPE_CHECKING:
        type: Literal["model"] = "model"
    else:
        type: Literal["model"]
    """Specialized resource type 'model'"""

    id: Optional[ModelId] = None
    """bioimage.io-wide unique resource identifier
    assigned by bioimage.io; version **un**specific."""

    authors: NotEmpty[  # pyright: ignore[reportGeneralTypeIssues] # make mandatory
        List[Author]
    ]
    """The authors are the creators of the model RDF and the primary points of contact."""

    documentation: Annotated[
        ImportantFileSource,
        Field(
            examples=[
                "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md",
                "README.md",
            ],
        ),
    ]
    """URL or relative path to a markdown file with additional documentation.
    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
    The documentation should include a '[#[#]]# Validation' (sub)section
    with details on how to quantitatively validate the model on unseen data."""

    inputs: NotEmpty[List[InputTensorDescr]]
    """Describes the input tensors expected by this model."""

    license: Annotated[
        Union[LicenseId, str],
        warn(LicenseId, "Unknown license id '{value}'."),
        Field(examples=["CC0-1.0", "MIT", "BSD-2-Clause"]),
    ]
    """A [SPDX license identifier](https://spdx.org/licenses/).
    We do not support custom licenses beyond the SPDX license list; if you need that, please
    [open a GitHub issue](https://github.com/bioimage-io/spec-bioimage-io/issues/new/choose)
    to discuss your intentions with the community."""

    name: Annotated[
        str,
        MinLen(1),
        warn(MinLen(5), "Name shorter than 5 characters.", INFO),
        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
    ]
    """A human-readable name of this model.
    It should be no longer than 64 characters and only contain letter, number, underscore, minus or space characters."""

    outputs: NotEmpty[List[OutputTensorDescr]]
    """Describes the output tensors."""

    @field_validator("inputs", "outputs")
    @classmethod
    def unique_tensor_descr_names(
        cls, value: Sequence[Union[InputTensorDescr, OutputTensorDescr]]
    ) -> Sequence[Union[InputTensorDescr, OutputTensorDescr]]:
        unique_names = {str(v.name) for v in value}
        if len(unique_names) != len(value):
            raise ValueError("Duplicate tensor descriptor names")

        return value

    @model_validator(mode="after")
    def unique_io_names(self) -> Self:
        unique_names = {str(ss.name) for s in (self.inputs, self.outputs) for ss in s}
        if len(unique_names) != (len(self.inputs) + len(self.outputs)):
            raise ValueError("Duplicate tensor descriptor names across inputs/outputs")

        return self

    @model_validator(mode="after")
    def minimum_shape2valid_output(self) -> Self:
        tensors_by_name: Dict[
            TensorName, Union[InputTensorDescr, OutputTensorDescr]
        ] = {t.name: t for t in self.inputs + self.outputs}

        for out in self.outputs:
            if isinstance(out.shape, ImplicitOutputShape):
                ndim_ref = len(tensors_by_name[out.shape.reference_tensor].shape)
                ndim_out_ref = len(
                    [scale for scale in out.shape.scale if scale is not None]
                )
                if ndim_ref != ndim_out_ref:
                    expanded_dim_note = (
                        " Note that expanded dimensions (`scale`: null) are not"
                        + f" counted for {out.name}'s dimensionality here."
                        if None in out.shape.scale
                        else ""
                    )
                    raise ValueError(
                        f"Referenced tensor '{out.shape.reference_tensor}' with"
                        + f" {ndim_ref} dimensions does not match output tensor"
                        + f" '{out.name}' with"
                        + f" {ndim_out_ref} dimensions.{expanded_dim_note}"
                    )

            min_out_shape = self._get_min_shape(out, tensors_by_name)
            if out.halo:
                halo = out.halo
                halo_msg = f" for halo {out.halo}"
            else:
                halo = [0] * len(min_out_shape)
                halo_msg = ""

            if any(s - 2 * h < 1 for s, h in zip(min_out_shape, halo)):
                raise ValueError(
                    f"Minimal shape {min_out_shape} of output {out.name} is too"
                    + f" small{halo_msg}."
                )

        return self

    @classmethod
    def _get_min_shape(
        cls,
        t: Union[InputTensorDescr, OutputTensorDescr],
        tensors_by_name: Dict[TensorName, Union[InputTensorDescr, OutputTensorDescr]],
    ) -> Sequence[int]:
        """output with subtracted halo has to result in meaningful output even for the minimal input
        see https://github.com/bioimage-io/spec-bioimage-io/issues/392
        """
        if isinstance(t.shape, collections.abc.Sequence):
            return t.shape
        elif isinstance(t.shape, ParameterizedInputShape):
            return t.shape.min
        elif isinstance(t.shape, ImplicitOutputShape):
            pass
        else:
            assert_never(t.shape)

        ref_shape = cls._get_min_shape(
            tensors_by_name[t.shape.reference_tensor], tensors_by_name
        )

        if None not in t.shape.scale:
            scale: Sequence[float] = t.shape.scale  # type: ignore
        else:
            expanded_dims = [idx for idx, sc in enumerate(t.shape.scale) if sc is None]
            new_ref_shape: List[int] = []
            for idx in range(len(t.shape.scale)):
                ref_idx = idx - sum(int(exp < idx) for exp in expanded_dims)
                new_ref_shape.append(1 if idx in expanded_dims else ref_shape[ref_idx])

            ref_shape = new_ref_shape
            assert len(ref_shape) == len(t.shape.scale)
            scale = [0.0 if sc is None else sc for sc in t.shape.scale]

        offset = t.shape.offset
        assert len(offset) == len(scale)
        return [int(rs * s + 2 * off) for rs, s, off in zip(ref_shape, scale, offset)]
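
    # Illustrative sketch (not part of the original module) of the branch above:
    # a reference tensor with minimal shape [1, 64, 64], an output with
    # scale [1.0, 2.0, 2.0, None] and offset [0, 0, 0, 7] yields
    # [1*1 + 0, 64*2 + 0, 64*2 + 0, 1*0 + 2*7] == [1, 128, 128, 14];
    # the expanded last axis (scale: None) gets its full length from 2 * offset.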

    @model_validator(mode="after")
    def validate_tensor_references_in_inputs(self) -> Self:
        for t in self.inputs:
            for proc in t.preprocessing:
                if "reference_tensor" not in proc.kwargs:
                    continue

                ref_tensor = proc.kwargs["reference_tensor"]
                if ref_tensor is not None and str(ref_tensor) not in {
                    str(t.name) for t in self.inputs
                }:
                    raise ValueError(f"'{ref_tensor}' not found in inputs")

                if ref_tensor == t.name:
                    raise ValueError(
                        f"invalid self reference for preprocessing of tensor {t.name}"
                    )

        return self

    @model_validator(mode="after")
    def validate_tensor_references_in_outputs(self) -> Self:
        for t in self.outputs:
            for proc in t.postprocessing:
                if "reference_tensor" not in proc.kwargs:
                    continue

                ref_tensor = proc.kwargs["reference_tensor"]
                if ref_tensor is not None and str(ref_tensor) not in {
                    str(t.name) for t in self.inputs
                }:
                    raise ValueError(f"{ref_tensor} not found in inputs")

        return self

    packaged_by: List[Author] = Field(default_factory=list)
    """The persons that have packaged and uploaded this model.
    Only required if those persons differ from the `authors`."""

    parent: Optional[LinkedModel] = None
    """The model from which this model is derived, e.g. by fine-tuning the weights."""

    @field_validator("parent", mode="before")
    @classmethod
    def ignore_url_parent(cls, parent: Any):
        if isinstance(parent, dict):
            return None
        else:
            return parent

    run_mode: Optional[RunMode] = None
    """Custom run mode for this model: for more complex prediction procedures like test time
    data augmentation that currently cannot be expressed in the specification.
    No standard run modes are defined yet."""

    sample_inputs: List[ImportantFileSource] = Field(default_factory=list)
    """URLs/relative paths to sample inputs to illustrate possible inputs for the model,
    for example stored as PNG or TIFF images.
    The sample files primarily serve to inform a human user about an example use case."""

    sample_outputs: List[ImportantFileSource] = Field(default_factory=list)
    """URLs/relative paths to sample outputs corresponding to the `sample_inputs`."""

    test_inputs: NotEmpty[
        List[Annotated[ImportantFileSource, WithSuffix(".npy", case_sensitive=True)]]
    ]
    """Test input tensors compatible with the `inputs` description for a **single test case**.
    This means if your model has more than one input, you should provide one URL/relative path for each input.
    Each test input should be a file with an ndarray in
    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
    The extension must be '.npy'."""

    test_outputs: NotEmpty[
        List[Annotated[ImportantFileSource, WithSuffix(".npy", case_sensitive=True)]]
    ]
    """Analog to `test_inputs`."""

    timestamp: Datetime
    """Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format
    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat)."""

    training_data: Union[LinkedDataset, DatasetDescr, None] = None
    """The dataset used to train this model"""

    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
    """The weights for this model.
    Weights can be given for different formats, but should otherwise be equivalent.
    The available weight formats determine which consumers can use this model."""

    @model_validator(mode="before")
    @classmethod
    def _convert_from_older_format(
        cls, data: BioimageioYamlContent, /
    ) -> BioimageioYamlContent:
        convert_from_older_format(data)
        return data

    def get_input_test_arrays(self) -> List[NDArray[Any]]:
        data = [load_array(download(ipt).path) for ipt in self.test_inputs]
        assert all(isinstance(d, np.ndarray) for d in data)
        return data

    def get_output_test_arrays(self) -> List[NDArray[Any]]:
        data = [load_array(download(out).path) for out in self.test_outputs]
        assert all(isinstance(d, np.ndarray) for d in data)
        return data