Coverage for bioimageio/spec/model/v0_4.py: 90%
595 statements
coverage.py v7.9.1, created at 2025-06-27 09:20 +0000
from __future__ import annotations

import collections.abc
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    ClassVar,
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
    cast,
)

import numpy as np
from annotated_types import Ge, Interval, MaxLen, MinLen, MultipleOf
from numpy.typing import NDArray
from pydantic import (
    AllowInfNan,
    Discriminator,
    Field,
    RootModel,
    SerializationInfo,
    SerializerFunctionWrapHandler,
    StringConstraints,
    TypeAdapter,
    ValidationInfo,
    WrapSerializer,
    field_validator,
    model_validator,
)
from typing_extensions import Annotated, Self, assert_never, get_args

from .._internal.common_nodes import (
    KwargsNode,
    Node,
    NodeWithExplicitlySetFields,
)
from .._internal.constants import SHA256_HINT
from .._internal.field_validation import validate_unique_entries
from .._internal.field_warning import issue_warning, warn
from .._internal.io import BioimageioYamlContent, WithSuffix
from .._internal.io import FileDescr as FileDescr
from .._internal.io_basics import Sha256 as Sha256
from .._internal.io_packaging import include_in_package
from .._internal.io_utils import load_array
from .._internal.packaging_context import packaging_context_var
from .._internal.types import Datetime as Datetime
from .._internal.types import FileSource_, LowerCaseIdentifier
from .._internal.types import Identifier as Identifier
from .._internal.types import LicenseId as LicenseId
from .._internal.types import NotEmpty as NotEmpty
from .._internal.url import HttpUrl as HttpUrl
from .._internal.validated_string_with_inner_node import ValidatedStringWithInnerNode
from .._internal.validator_annotations import AfterValidator, RestrictCharacters
from .._internal.version_type import Version as Version
from .._internal.warning_levels import ALERT, INFO
from ..dataset.v0_2 import VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS
from ..dataset.v0_2 import DatasetDescr as DatasetDescr
from ..dataset.v0_2 import LinkedDataset as LinkedDataset
from ..generic.v0_2 import AttachmentsDescr as AttachmentsDescr
from ..generic.v0_2 import Author as Author
from ..generic.v0_2 import BadgeDescr as BadgeDescr
from ..generic.v0_2 import CiteEntry as CiteEntry
from ..generic.v0_2 import Doi as Doi
from ..generic.v0_2 import GenericModelDescrBase
from ..generic.v0_2 import LinkedResource as LinkedResource
from ..generic.v0_2 import Maintainer as Maintainer
from ..generic.v0_2 import OrcidId as OrcidId
from ..generic.v0_2 import RelativeFilePath as RelativeFilePath
from ..generic.v0_2 import ResourceId as ResourceId
from ..generic.v0_2 import Uploader as Uploader
from ._v0_4_converter import convert_from_older_format


class ModelId(ResourceId):
    pass


AxesStr = Annotated[
    str, RestrictCharacters("bitczyx"), AfterValidator(validate_unique_entries)
]
AxesInCZYX = Annotated[
    str, RestrictCharacters("czyx"), AfterValidator(validate_unique_entries)
]

PostprocessingName = Literal[
    "binarize",
    "clip",
    "scale_linear",
    "sigmoid",
    "zero_mean_unit_variance",
    "scale_range",
    "scale_mean_variance",
]
PreprocessingName = Literal[
    "binarize",
    "clip",
    "scale_linear",
    "sigmoid",
    "zero_mean_unit_variance",
    "scale_range",
]


class TensorName(LowerCaseIdentifier):
    pass


class CallableFromDepencencyNode(Node):
    _submodule_adapter: ClassVar[TypeAdapter[Identifier]] = TypeAdapter(Identifier)

    module_name: str
    """The Python module that implements **callable_name**."""

    @field_validator("module_name", mode="after")
    def _check_submodules(cls, module_name: str) -> str:
        for submod in module_name.split("."):
            _ = cls._submodule_adapter.validate_python(submod)

        return module_name

    callable_name: Identifier
    """The callable Python identifier implemented in module **module_name**."""


class CallableFromDepencency(ValidatedStringWithInnerNode[CallableFromDepencencyNode]):
    _inner_node_class = CallableFromDepencencyNode
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            str,
            StringConstraints(strip_whitespace=True, pattern=r"^.+\..+$"),
        ]
    ]

    @classmethod
    def _get_data(cls, valid_string_data: str):
        *mods, callname = valid_string_data.split(".")
        return dict(module_name=".".join(mods), callable_name=callname)

    @property
    def module_name(self):
        """The Python module that implements **callable_name**."""
        return self._inner_node.module_name

    @property
    def callable_name(self):
        """The callable Python identifier implemented in module **module_name**."""
        return self._inner_node.callable_name
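

# Illustrative usage sketch (not part of the original module); it assumes the
# validated string type can be instantiated directly from its string form.
# A dotted string is split into module path and callable name:
#
#     c = CallableFromDepencency("my_module.submodule.get_my_model")
#     assert c.module_name == "my_module.submodule"
#     assert c.callable_name == "get_my_model"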


class CallableFromFileNode(Node):
    source_file: Annotated[
        Union[RelativeFilePath, HttpUrl],
        Field(union_mode="left_to_right"),
        include_in_package,
    ]
    """The Python source file that implements **callable_name**."""
    callable_name: Identifier
    """The callable Python identifier implemented in **source_file**."""


class CallableFromFile(ValidatedStringWithInnerNode[CallableFromFileNode]):
    _inner_node_class = CallableFromFileNode
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            str,
            StringConstraints(strip_whitespace=True, pattern=r"^.+:.+$"),
        ]
    ]

    @classmethod
    def _get_data(cls, valid_string_data: str):
        *file_parts, callname = valid_string_data.split(":")
        return dict(source_file=":".join(file_parts), callable_name=callname)

    @property
    def source_file(self):
        """The Python source file that implements **callable_name**."""
        return self._inner_node.source_file

    @property
    def callable_name(self):
        """The callable Python identifier implemented in **source_file**."""
        return self._inner_node.callable_name
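

# Illustrative usage sketch (not part of the original module). Note that
# splitting on ":" and re-joining all but the last part keeps URL sources like
# "https://example.com/my_function.py:MyNetworkClass" intact:
#
#     c = CallableFromFile("my_function.py:MyNetworkClass")
#     assert str(c.source_file) == "my_function.py"
#     assert c.callable_name == "MyNetworkClass"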


CustomCallable = Annotated[
    Union[CallableFromFile, CallableFromDepencency], Field(union_mode="left_to_right")
]


class DependenciesNode(Node):
    manager: Annotated[NotEmpty[str], Field(examples=["conda", "maven", "pip"])]
    """Dependency manager"""

    file: FileSource_
    """Dependency file"""


class Dependencies(ValidatedStringWithInnerNode[DependenciesNode]):
    _inner_node_class = DependenciesNode
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            str,
            StringConstraints(strip_whitespace=True, pattern=r"^.+:.+$"),
        ]
    ]

    @classmethod
    def _get_data(cls, valid_string_data: str):
        manager, *file_parts = valid_string_data.split(":")
        return dict(manager=manager, file=":".join(file_parts))

    @property
    def manager(self):
        """Dependency manager"""
        return self._inner_node.manager

    @property
    def file(self):
        """Dependency file"""
        return self._inner_node.file
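

# Illustrative usage sketch (not part of the original module): the string form
# is `<dependency manager>:<relative file path>`, split on the first ":":
#
#     d = Dependencies("conda:environment.yaml")
#     assert d.manager == "conda"
#     assert str(d.file) == "environment.yaml"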


WeightsFormat = Literal[
    "keras_hdf5",
    "onnx",
    "pytorch_state_dict",
    "tensorflow_js",
    "tensorflow_saved_model_bundle",
    "torchscript",
]


class WeightsEntryDescrBase(FileDescr):
    type: ClassVar[WeightsFormat]
    weights_format_name: ClassVar[str]  # human readable

    source: FileSource_
    """The weights file."""

    attachments: Annotated[
        Union[AttachmentsDescr, None],
        warn(None, "Weights entry depends on additional attachments.", ALERT),
    ] = None
    """Attachments that are specific to this weights entry."""

    authors: Union[List[Author], None] = None
    """Authors:
    Either the person(s) that trained this model, resulting in the original weights file
    (if this is the initial weights entry, i.e. it does not have a `parent`),
    or the person(s) who converted the weights to this weights format
    (if this is a child weights entry, i.e. it has a `parent` field).
    """

    dependencies: Annotated[
        Optional[Dependencies],
        warn(
            None,
            "Custom dependencies ({value}) specified. Avoid this whenever possible "
            + "to allow execution in a wider range of software environments.",
        ),
        Field(
            examples=[
                "conda:environment.yaml",
                "maven:./pom.xml",
                "pip:./requirements.txt",
            ]
        ),
    ] = None
    """Dependency manager and dependency file, specified as `<dependency manager>:<relative file path>`."""

    parent: Annotated[
        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
    ] = None
    """The source weights these weights were converted from.
    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
    the `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
    All weight entries except one (the initial set of weights resulting from training the model)
    need to have this field."""

    @model_validator(mode="after")
    def check_parent_is_not_self(self) -> Self:
        if self.type == self.parent:
            raise ValueError("Weights entry can't be its own parent.")

        return self


class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
    type = "keras_hdf5"
    weights_format_name: ClassVar[str] = "Keras HDF5"
    tensorflow_version: Optional[Version] = None
    """TensorFlow version used to create these weights"""

    @field_validator("tensorflow_version", mode="after")
    @classmethod
    def _tfv(cls, value: Any):
        if value is None:
            issue_warning(
                "missing. Please specify the TensorFlow version"
                + " these weights were created with.",
                value=value,
                severity=ALERT,
                field="tensorflow_version",
            )
        return value


class OnnxWeightsDescr(WeightsEntryDescrBase):
    type = "onnx"
    weights_format_name: ClassVar[str] = "ONNX"
    opset_version: Optional[Annotated[int, Ge(7)]] = None
    """ONNX opset version"""

    @field_validator("opset_version", mode="after")
    @classmethod
    def _ov(cls, value: Any):
        if value is None:
            issue_warning(
                "Missing ONNX opset version (aka ONNX opset number). "
                + "Please specify the ONNX opset version these weights were created"
                + " with.",
                value=value,
                severity=ALERT,
                field="opset_version",
            )
        return value


class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
    type = "pytorch_state_dict"
    weights_format_name: ClassVar[str] = "Pytorch State Dict"
    architecture: CustomCallable = Field(
        examples=["my_function.py:MyNetworkClass", "my_module.submodule.get_my_model"]
    )
    """Callable returning a torch.nn.Module instance.
    Local implementation: `<relative path to file>:<identifier of implementation within the file>`.
    Implementation in a dependency: `<dependency-package>.<[dependency-module]>.<identifier>`."""

    architecture_sha256: Annotated[
        Optional[Sha256],
        Field(
            description=(
                "The SHA256 of the architecture source file, if the architecture is not"
                " defined in a module listed in `dependencies`\n"
            )
            + SHA256_HINT,
        ),
    ] = None
    """The SHA256 of the architecture source file,
    if the architecture is not defined in a module listed in `dependencies`"""

    @model_validator(mode="after")
    def check_architecture_sha256(self) -> Self:
        if isinstance(self.architecture, CallableFromFile):
            if self.architecture_sha256 is None:
                raise ValueError(
                    "Missing required `architecture_sha256` for `architecture` with"
                    + " source file."
                )
        elif self.architecture_sha256 is not None:
            raise ValueError(
                "Got `architecture_sha256` for architecture that does not have a source"
                + " file."
            )

        return self

    kwargs: Dict[str, Any] = Field(
        default_factory=cast(Callable[[], Dict[str, Any]], dict)
    )
    """Keyword arguments for the `architecture` callable"""

    pytorch_version: Optional[Version] = None
    """Version of the PyTorch library used.
    If `dependencies` is specified it should include pytorch and the version has to match.
    (`dependencies` overrules `pytorch_version`)"""

    @field_validator("pytorch_version", mode="after")
    @classmethod
    def _ptv(cls, value: Any):
        if value is None:
            issue_warning(
                "missing. Please specify the PyTorch version these"
                + " PyTorch state dict weights were created with.",
                value=value,
                severity=ALERT,
                field="pytorch_version",
            )
        return value


class TorchscriptWeightsDescr(WeightsEntryDescrBase):
    type = "torchscript"
    weights_format_name: ClassVar[str] = "TorchScript"
    pytorch_version: Optional[Version] = None
    """Version of the PyTorch library used."""

    @field_validator("pytorch_version", mode="after")
    @classmethod
    def _ptv(cls, value: Any):
        if value is None:
            issue_warning(
                "missing. Please specify the PyTorch version these"
                + " Torchscript weights were created with.",
                value=value,
                severity=ALERT,
                field="pytorch_version",
            )
        return value


class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
    type = "tensorflow_js"
    weights_format_name: ClassVar[str] = "Tensorflow.js"
    tensorflow_version: Optional[Version] = None
    """Version of the TensorFlow library used."""

    @field_validator("tensorflow_version", mode="after")
    @classmethod
    def _tfv(cls, value: Any):
        if value is None:
            issue_warning(
                "missing. Please specify the TensorFlow version"
                + " these TensorflowJs weights were created with.",
                value=value,
                severity=ALERT,
                field="tensorflow_version",
            )
        return value

    source: FileSource_
    """The multi-file weights.
    All required files/folders should be in a zip archive."""


class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
    type = "tensorflow_saved_model_bundle"
    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
    tensorflow_version: Optional[Version] = None
    """Version of the TensorFlow library used."""

    @field_validator("tensorflow_version", mode="after")
    @classmethod
    def _tfv(cls, value: Any):
        if value is None:
            issue_warning(
                "missing. Please specify the TensorFlow version"
                + " these Tensorflow saved model bundle weights were created with.",
                value=value,
                severity=ALERT,
                field="tensorflow_version",
            )
        return value


class WeightsDescr(Node):
    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
    onnx: Optional[OnnxWeightsDescr] = None
    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
        None
    )
    torchscript: Optional[TorchscriptWeightsDescr] = None

    @model_validator(mode="after")
    def check_one_entry(self) -> Self:
        if all(
            entry is None
            for entry in [
                self.keras_hdf5,
                self.onnx,
                self.pytorch_state_dict,
                self.tensorflow_js,
                self.tensorflow_saved_model_bundle,
                self.torchscript,
            ]
        ):
            raise ValueError("Missing weights entry")

        return self

    def __getitem__(
        self,
        key: WeightsFormat,
    ):
        if key == "keras_hdf5":
            ret = self.keras_hdf5
        elif key == "onnx":
            ret = self.onnx
        elif key == "pytorch_state_dict":
            ret = self.pytorch_state_dict
        elif key == "tensorflow_js":
            ret = self.tensorflow_js
        elif key == "tensorflow_saved_model_bundle":
            ret = self.tensorflow_saved_model_bundle
        elif key == "torchscript":
            ret = self.torchscript
        else:
            raise KeyError(key)

        if ret is None:
            raise KeyError(key)

        return ret

    @property
    def available_formats(self):
        return {
            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
            **({} if self.onnx is None else {"onnx": self.onnx}),
            **(
                {}
                if self.pytorch_state_dict is None
                else {"pytorch_state_dict": self.pytorch_state_dict}
            ),
            **(
                {}
                if self.tensorflow_js is None
                else {"tensorflow_js": self.tensorflow_js}
            ),
            **(
                {}
                if self.tensorflow_saved_model_bundle is None
                else {
                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
                }
            ),
            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
        }

    @property
    def missing_formats(self):
        return {
            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
        }
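

# Illustrative usage sketch (not part of the original module); `source=...` is
# a placeholder for an actual weights file source:
#
#     w = WeightsDescr(onnx=OnnxWeightsDescr(source=..., opset_version=15))
#     w["onnx"]            # returns the OnnxWeightsDescr entry
#     w.available_formats  # {"onnx": ...}
#     w.missing_formats    # the names of the other five weight formats
#     w["torchscript"]     # raises KeyError: entry is not set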


class ParameterizedInputShape(Node):
    """A sequence of valid shapes given by `shape_k = min + k * step for k in {0, 1, ...}`."""

    min: NotEmpty[List[int]]
    """The minimum input shape"""

    step: NotEmpty[List[int]]
    """The minimum shape change"""

    def __len__(self) -> int:
        return len(self.min)

    @model_validator(mode="after")
    def matching_lengths(self) -> Self:
        if len(self.min) != len(self.step):
            raise ValueError("`min` and `step` are required to have the same length")

        return self
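

# Illustrative sketch (not part of the original module): valid shapes are
# `min + k * step` for k = 0, 1, ...
#
#     s = ParameterizedInputShape(min=[1, 64, 64, 1], step=[0, 32, 32, 0])
#     [m + 2 * st for m, st in zip(s.min, s.step)]  # k=2 -> [1, 128, 128, 1]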


class ImplicitOutputShape(Node):
    """Output tensor shape depending on an input tensor shape.
    `shape(output_tensor) = shape(input_tensor) * scale + 2 * offset`"""

    reference_tensor: TensorName
    """Name of the reference tensor."""

    scale: NotEmpty[List[Optional[float]]]
    """output_pix/input_pix for each dimension.
    'null' values indicate new dimensions, whose length is defined by 2*`offset`"""

    offset: NotEmpty[List[Union[int, Annotated[float, MultipleOf(0.5)]]]]
    """Position of origin with respect to the input."""

    def __len__(self) -> int:
        return len(self.scale)

    @model_validator(mode="after")
    def matching_lengths(self) -> Self:
        if len(self.scale) != len(self.offset):
            raise ValueError(
                f"scale {self.scale} has to have same length as offset {self.offset}!"
            )
        # if we have an expanded dimension, make sure that its offset is not zero
        for sc, off in zip(self.scale, self.offset):
            if sc is None and not off:
                raise ValueError("`offset` must not be zero if `scale` is none/zero")

        return self
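

# Illustrative sketch (not part of the original module): the output shape
# follows `shape(output) = shape(input) * scale + 2 * offset` per axis.
#
#     s = ImplicitOutputShape(reference_tensor="raw", scale=[1.0, 0.5], offset=[0, 8])
#     [int(i * sc + 2 * off) for i, sc, off in zip([256, 256], s.scale, s.offset)]
#     # -> [256, 144]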


class TensorDescrBase(Node):
    name: TensorName
    """Tensor name. No duplicates are allowed."""

    description: str = ""

    axes: AxesStr
    """Axes identifying characters. Same length and order as the axes in `shape`.
    | axis | description                     |
    | ---- | ------------------------------- |
    | b    | batch (groups multiple samples) |
    | i    | instance/index/element          |
    | t    | time                            |
    | c    | channel                         |
    | z    | spatial dimension z             |
    | y    | spatial dimension y             |
    | x    | spatial dimension x             |
    """

    data_range: Optional[
        Tuple[Annotated[float, AllowInfNan(True)], Annotated[float, AllowInfNan(True)]]
    ] = None
    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
    If not specified, the full data range that can be expressed in `data_type` is allowed."""


class ProcessingKwargs(KwargsNode):
    """base class for pre-/postprocessing keyword arguments"""


class ProcessingDescrBase(NodeWithExplicitlySetFields):
    """processing base class"""


class BinarizeKwargs(ProcessingKwargs):
    """keyword arguments for `BinarizeDescr`"""

    threshold: float
    """The fixed threshold"""


class BinarizeDescr(ProcessingDescrBase):
    """Binarize the tensor with a fixed `BinarizeKwargs.threshold`.
    Values above the threshold will be set to one, values below the threshold to zero.
    """

    implemented_name: ClassVar[Literal["binarize"]] = "binarize"
    if TYPE_CHECKING:
        name: Literal["binarize"] = "binarize"
    else:
        name: Literal["binarize"]

    kwargs: BinarizeKwargs


class ClipKwargs(ProcessingKwargs):
    """keyword arguments for `ClipDescr`"""

    min: float
    """minimum value for clipping"""
    max: float
    """maximum value for clipping"""


class ClipDescr(ProcessingDescrBase):
    """Clip tensor values to a range.

    Set tensor values below `ClipKwargs.min` to `ClipKwargs.min`
    and above `ClipKwargs.max` to `ClipKwargs.max`.
    """

    implemented_name: ClassVar[Literal["clip"]] = "clip"
    if TYPE_CHECKING:
        name: Literal["clip"] = "clip"
    else:
        name: Literal["clip"]

    kwargs: ClipKwargs


class ScaleLinearKwargs(ProcessingKwargs):
    """keyword arguments for `ScaleLinearDescr`"""

    axes: Annotated[Optional[AxesInCZYX], Field(examples=["xy"])] = None
    """The subset of axes to scale jointly.
    For example xy to scale the two image axes for 2d data jointly."""

    gain: Union[float, List[float]] = 1.0
    """multiplicative factor"""

    offset: Union[float, List[float]] = 0.0
    """additive term"""

    @model_validator(mode="after")
    def either_gain_or_offset(self) -> Self:
        if (
            self.gain == 1.0
            or isinstance(self.gain, list)
            and all(g == 1.0 for g in self.gain)
        ) and (
            self.offset == 0.0
            or isinstance(self.offset, list)
            and all(off == 0.0 for off in self.offset)
        ):
            raise ValueError(
                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or"
                + " `offset` != 0.0."
            )

        return self


class ScaleLinearDescr(ProcessingDescrBase):
    """Fixed linear scaling."""

    implemented_name: ClassVar[Literal["scale_linear"]] = "scale_linear"
    if TYPE_CHECKING:
        name: Literal["scale_linear"] = "scale_linear"
    else:
        name: Literal["scale_linear"]

    kwargs: ScaleLinearKwargs


class SigmoidDescr(ProcessingDescrBase):
    """The logistic sigmoid function, a.k.a. expit function."""

    implemented_name: ClassVar[Literal["sigmoid"]] = "sigmoid"
    if TYPE_CHECKING:
        name: Literal["sigmoid"] = "sigmoid"
    else:
        name: Literal["sigmoid"]

    @property
    def kwargs(self) -> ProcessingKwargs:
        """empty kwargs"""
        return ProcessingKwargs()


class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
    """keyword arguments for `ZeroMeanUnitVarianceDescr`"""

    mode: Literal["fixed", "per_dataset", "per_sample"] = "fixed"
    """Mode for computing mean and variance.
    | mode        | description                          |
    | ----------- | ------------------------------------ |
    | fixed       | Fixed values for mean and variance   |
    | per_dataset | Compute for the entire dataset       |
    | per_sample  | Compute for each sample individually |
    """
    axes: Annotated[AxesInCZYX, Field(examples=["xy"])]
    """The subset of axes to normalize jointly.
    For example `xy` to normalize the two image axes for 2d data jointly."""

    mean: Annotated[
        Union[float, NotEmpty[List[float]], None], Field(examples=[(1.1, 2.2, 3.3)])
    ] = None
    """The mean value(s) to use for `mode: fixed`.
    For example `[1.1, 2.2, 3.3]` in the case of a 3 channel image with `axes: xy`."""
    # todo: check if means match input axes (for mode 'fixed')

    std: Annotated[
        Union[float, NotEmpty[List[float]], None], Field(examples=[(0.1, 0.2, 0.3)])
    ] = None
    """The standard deviation values to use for `mode: fixed`. Analogous to mean."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""

    @model_validator(mode="after")
    def mean_and_std_match_mode(self) -> Self:
        if self.mode == "fixed" and (self.mean is None or self.std is None):
            raise ValueError("`mean` and `std` are required for `mode: fixed`.")
        elif self.mode != "fixed" and (self.mean is not None or self.std is not None):
            raise ValueError(f"`mean` and `std` not allowed for `mode: {self.mode}`")

        return self
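

# Illustrative sketch (not part of the original module): fixed-mode
# normalization requires explicit mean and std values, e.g. per channel:
#
#     kwargs = ZeroMeanUnitVarianceKwargs(
#         mode="fixed", axes="xy", mean=[1.1, 2.2, 3.3], std=[0.1, 0.2, 0.3]
#     )
#     # applied as: out = (tensor - mean) / (std + eps)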


class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
    """Subtract mean and divide by standard deviation."""

    implemented_name: ClassVar[Literal["zero_mean_unit_variance"]] = (
        "zero_mean_unit_variance"
    )
    if TYPE_CHECKING:
        name: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
    else:
        name: Literal["zero_mean_unit_variance"]

    kwargs: ZeroMeanUnitVarianceKwargs


class ScaleRangeKwargs(ProcessingKwargs):
    """keyword arguments for `ScaleRangeDescr`

    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
    this processing step normalizes data to the [0, 1] interval.
    For other percentiles the normalized values will partially be outside the [0, 1]
    interval. Use `ScaleRange` followed by `ClipDescr` if you want to limit the
    normalized values to a range.
    """

    mode: Literal["per_dataset", "per_sample"]
    """Mode for computing percentiles.
    | mode        | description                          |
    | ----------- | ------------------------------------ |
    | per_dataset | compute for the entire dataset       |
    | per_sample  | compute for each sample individually |
    """
    axes: Annotated[AxesInCZYX, Field(examples=["xy"])]
    """The subset of axes to normalize jointly.
    For example xy to normalize the two image axes for 2d data jointly."""

    min_percentile: Annotated[Union[int, float], Interval(ge=0, lt=100)] = 0.0
    """The lower percentile used to determine the value to align with zero."""

    max_percentile: Annotated[Union[int, float], Interval(gt=1, le=100)] = 100.0
    """The upper percentile used to determine the value to align with one.
    Has to be bigger than `min_percentile`.
    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
    accepting percentiles specified in the range 0.0 to 1.0."""

    @model_validator(mode="after")
    def min_smaller_max(self, info: ValidationInfo) -> Self:
        if self.min_percentile >= self.max_percentile:
            raise ValueError(
                f"min_percentile {self.min_percentile} >= max_percentile"
                + f" {self.max_percentile}"
            )

        return self

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability.
    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
    with `v_lower,v_upper` values at the respective percentiles."""

    reference_tensor: Optional[TensorName] = None
    """Tensor name to compute the percentiles from. Default: The tensor itself.
    For any tensor in `inputs` only input tensor references are allowed.
    For a tensor in `outputs` only input tensor references are allowed if `mode: per_dataset`"""
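

# Illustrative sketch (not part of the original module) of the scale_range
# computation, expressed with numpy percentiles:
#
#     import numpy as np
#     v_lower, v_upper = np.percentile(tensor, [min_percentile, max_percentile])
#     out = (tensor - v_lower) / (v_upper - v_lower + eps)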


class ScaleRangeDescr(ProcessingDescrBase):
    """Scale with percentiles."""

    implemented_name: ClassVar[Literal["scale_range"]] = "scale_range"
    if TYPE_CHECKING:
        name: Literal["scale_range"] = "scale_range"
    else:
        name: Literal["scale_range"]

    kwargs: ScaleRangeKwargs


class ScaleMeanVarianceKwargs(ProcessingKwargs):
    """keyword arguments for `ScaleMeanVarianceDescr`"""

    mode: Literal["per_dataset", "per_sample"]
    """Mode for computing mean and variance.
    | mode        | description                          |
    | ----------- | ------------------------------------ |
    | per_dataset | Compute for the entire dataset       |
    | per_sample  | Compute for each sample individually |
    """

    reference_tensor: TensorName
    """Name of tensor to match."""

    axes: Annotated[Optional[AxesInCZYX], Field(examples=["xy"])] = None
    """The subset of axes to scale jointly.
    For example xy to normalize the two image axes for 2d data jointly.
    Default: scale all non-batch axes jointly."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability:
    `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean`."""


class ScaleMeanVarianceDescr(ProcessingDescrBase):
    """Scale the tensor such that its mean and variance match a reference tensor."""

    implemented_name: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
    if TYPE_CHECKING:
        name: Literal["scale_mean_variance"] = "scale_mean_variance"
    else:
        name: Literal["scale_mean_variance"]

    kwargs: ScaleMeanVarianceKwargs


PreprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        ScaleLinearDescr,
        SigmoidDescr,
        ZeroMeanUnitVarianceDescr,
        ScaleRangeDescr,
    ],
    Discriminator("name"),
]
PostprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        ScaleLinearDescr,
        SigmoidDescr,
        ZeroMeanUnitVarianceDescr,
        ScaleRangeDescr,
        ScaleMeanVarianceDescr,
    ],
    Discriminator("name"),
]


class InputTensorDescr(TensorDescrBase):
    data_type: Literal["float32", "uint8", "uint16"]
    """For now an input tensor is expected to be given as `float32`.
    The data flow in bioimage.io models is explained
    [in this diagram](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit)."""

    shape: Annotated[
        Union[Sequence[int], ParameterizedInputShape],
        Field(
            examples=[(1, 512, 512, 1), dict(min=(1, 64, 64, 1), step=(0, 32, 32, 0))]
        ),
    ]
    """Specification of input tensor shape."""

    preprocessing: List[PreprocessingDescr] = Field(
        default_factory=cast(  # TODO: (py>3.8) use list[PreprocessingDescr]
            Callable[[], List[PreprocessingDescr]], list
        )
    )
    """Description of how this input should be preprocessed."""

    @model_validator(mode="after")
    def zero_batch_step_and_one_batch_size(self) -> Self:
        bidx = self.axes.find("b")
        if bidx == -1:
            return self

        if isinstance(self.shape, ParameterizedInputShape):
            step = self.shape.step
            shape = self.shape.min
            if step[bidx] != 0:
                raise ValueError(
                    "Input shape step has to be zero in the batch dimension (the batch"
                    + " dimension can always be increased, but `step` should specify"
                    + " how to increase the minimal shape to find the largest single"
                    + " batch shape)"
                )
        else:
            shape = self.shape

        if shape[bidx] != 1:
            raise ValueError("Input shape has to be 1 in the batch dimension b.")

        return self

    @model_validator(mode="after")
    def validate_preprocessing_kwargs(self) -> Self:
        for p in self.preprocessing:
            kwargs_axes = p.kwargs.get("axes")
            if isinstance(kwargs_axes, str) and any(
                a not in self.axes for a in kwargs_axes
            ):
                raise ValueError("`kwargs.axes` needs to be a subset of `axes`")

        return self
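

# Illustrative sketch (not part of the original module): a minimal input
# tensor description with a parameterized shape; pydantic coerces the dict
# into a ParameterizedInputShape:
#
#     ipt = InputTensorDescr(
#         name="raw",
#         axes="bxyc",
#         data_type="float32",
#         shape=dict(min=[1, 64, 64, 1], step=[0, 32, 32, 0]),
#     )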


class OutputTensorDescr(TensorDescrBase):
    data_type: Literal[
        "float32",
        "float64",
        "uint8",
        "int8",
        "uint16",
        "int16",
        "uint32",
        "int32",
        "uint64",
        "int64",
        "bool",
    ]
    """Data type.
    The data flow in bioimage.io models is explained
    [in this diagram](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit)."""

    shape: Union[Sequence[int], ImplicitOutputShape]
    """Output tensor shape."""

    halo: Optional[Sequence[int]] = None
    """The `halo` that should be cropped from the output tensor to avoid boundary effects.
    The `halo` is to be cropped from both sides, i.e. `shape_after_crop = shape - 2 * halo`.
    To document a `halo` that is already cropped by the model, `shape.offset` has to be used instead."""

    postprocessing: List[PostprocessingDescr] = Field(
        default_factory=cast(Callable[[], List[PostprocessingDescr]], list)
    )
    """Description of how this output should be postprocessed."""

    @model_validator(mode="after")
    def matching_halo_length(self) -> Self:
        if self.halo and len(self.halo) != len(self.shape):
            raise ValueError(
                f"halo {self.halo} has to have same length as shape {self.shape}!"
            )

        return self

    @model_validator(mode="after")
    def validate_postprocessing_kwargs(self) -> Self:
        for p in self.postprocessing:
            kwargs_axes = p.kwargs.get("axes", "")
            if not isinstance(kwargs_axes, str):
                raise ValueError(f"Expected {kwargs_axes} to be a string")

            if any(a not in self.axes for a in kwargs_axes):
                raise ValueError("`kwargs.axes` needs to be a subset of `axes`")

        return self


KnownRunMode = Literal["deepimagej"]


class RunMode(Node):
    name: Annotated[
        Union[KnownRunMode, str], warn(KnownRunMode, "Unknown run mode '{value}'.")
    ]
    """Run mode name"""

    kwargs: Dict[str, Any] = Field(
        default_factory=cast(Callable[[], Dict[str, Any]], dict)
    )
    """Run mode specific keyword arguments"""


class LinkedModel(Node):
    """Reference to a bioimage.io model."""

    id: Annotated[ModelId, Field(examples=["affable-shark", "ambitious-sloth"])]
    """A valid model `id` from the bioimage.io collection."""

    version_number: Optional[int] = None
    """Version number (n-th published version, not the semantic version) of the linked model"""


def package_weights(
    value: Node,  # Union[v0_4.WeightsDescr, v0_5.WeightsDescr]
    handler: SerializerFunctionWrapHandler,
    info: SerializationInfo,
):
    ctxt = packaging_context_var.get()
    if ctxt is not None and ctxt.weights_priority_order is not None:
        for wf in ctxt.weights_priority_order:
            w = getattr(value, wf, None)
            if w is not None:
                break
        else:
            raise ValueError(
                "None of the weight formats in `weights_priority_order`"
                + f" ({ctxt.weights_priority_order}) is present in the given model."
            )

        assert isinstance(w, Node), type(w)
        # construct WeightsDescr with new single weight format entry
        new_w = w.model_construct(**{k: v for k, v in w if k != "parent"})
        value = value.model_construct(None, **{wf: new_w})

    return handler(
        value, info  # pyright: ignore[reportArgumentType]  # taken from pydantic docs
    )


class ModelDescr(GenericModelDescrBase):
    """Specification of the fields used in a bioimage.io-compliant RDF that describes AI models with pretrained weights.

    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
    """

    implemented_format_version: ClassVar[Literal["0.4.10"]] = "0.4.10"
    if TYPE_CHECKING:
        format_version: Literal["0.4.10"] = "0.4.10"
    else:
        format_version: Literal["0.4.10"]
    """Version of the bioimage.io model description specification used.
    When creating a new model always use the latest micro/patch version described here.
    The `format_version` is important for any consumer software to understand how to parse the fields.
    """

    implemented_type: ClassVar[Literal["model"]] = "model"
    if TYPE_CHECKING:
        type: Literal["model"] = "model"
    else:
        type: Literal["model"]
    """Specialized resource type 'model'"""

    id: Optional[ModelId] = None
    """bioimage.io-wide unique resource identifier
    assigned by bioimage.io; version **un**specific."""

    authors: NotEmpty[  # pyright: ignore[reportGeneralTypeIssues]  # make mandatory
        List[Author]
    ]
    """The authors are the creators of the model RDF and the primary points of contact."""

    documentation: Annotated[
        FileSource_,
        Field(
            examples=[
                "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md",
                "README.md",
            ],
        ),
    ]
    """URL or relative path to a markdown file with additional documentation.
    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
    The documentation should include a '[#[#]]# Validation' (sub)section
    with details on how to quantitatively validate the model on unseen data."""

    inputs: NotEmpty[List[InputTensorDescr]]
    """Describes the input tensors expected by this model."""

    license: Annotated[
        Union[LicenseId, str],
        warn(LicenseId, "Unknown license id '{value}'."),
        Field(examples=["CC0-1.0", "MIT", "BSD-2-Clause"]),
    ]
    """A [SPDX license identifier](https://spdx.org/licenses/).
    We do not support custom licenses beyond the SPDX license list; if you need that, please
    [open a GitHub issue](https://github.com/bioimage-io/spec-bioimage-io/issues/new/choose)
    to discuss your intentions with the community."""

    name: Annotated[
        str,
        MinLen(1),
        warn(MinLen(5), "Name shorter than 5 characters.", INFO),
        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
    ]
    """A human-readable name of this model.
    It should be no longer than 64 characters and only contain letter, number, underscore, minus or space characters."""

    outputs: NotEmpty[List[OutputTensorDescr]]
    """Describes the output tensors."""

    @field_validator("inputs", "outputs")
    @classmethod
    def unique_tensor_descr_names(
        cls, value: Sequence[Union[InputTensorDescr, OutputTensorDescr]]
    ) -> Sequence[Union[InputTensorDescr, OutputTensorDescr]]:
        unique_names = {str(v.name) for v in value}
        if len(unique_names) != len(value):
            raise ValueError("Duplicate tensor descriptor names")

        return value

    @model_validator(mode="after")
    def unique_io_names(self) -> Self:
        unique_names = {str(ss.name) for s in (self.inputs, self.outputs) for ss in s}
        if len(unique_names) != (len(self.inputs) + len(self.outputs)):
            raise ValueError("Duplicate tensor descriptor names across inputs/outputs")

        return self

    @model_validator(mode="after")
    def minimum_shape2valid_output(self) -> Self:
        tensors_by_name: Dict[
            TensorName, Union[InputTensorDescr, OutputTensorDescr]
        ] = {t.name: t for t in self.inputs + self.outputs}

        for out in self.outputs:
            if isinstance(out.shape, ImplicitOutputShape):
                ndim_ref = len(tensors_by_name[out.shape.reference_tensor].shape)
                ndim_out_ref = len(
                    [scale for scale in out.shape.scale if scale is not None]
                )
                if ndim_ref != ndim_out_ref:
                    expanded_dim_note = (
                        " Note that expanded dimensions (`scale`: null) are not"
                        + f" counted for {out.name}'s dimensionality here."
                        if None in out.shape.scale
                        else ""
                    )
                    raise ValueError(
                        f"Referenced tensor '{out.shape.reference_tensor}' with"
                        + f" {ndim_ref} dimensions does not match output tensor"
                        + f" '{out.name}' with"
                        + f" {ndim_out_ref} dimensions.{expanded_dim_note}"
                    )

            min_out_shape = self._get_min_shape(out, tensors_by_name)
            if out.halo:
                halo = out.halo
                halo_msg = f" for halo {out.halo}"
            else:
                halo = [0] * len(min_out_shape)
                halo_msg = ""

            if any([s - 2 * h < 1 for s, h in zip(min_out_shape, halo)]):
                raise ValueError(
                    f"Minimal shape {min_out_shape} of output {out.name} is too"
                    + f" small{halo_msg}."
                )

        return self

    @classmethod
    def _get_min_shape(
        cls,
        t: Union[InputTensorDescr, OutputTensorDescr],
        tensors_by_name: Dict[TensorName, Union[InputTensorDescr, OutputTensorDescr]],
    ) -> Sequence[int]:
        """Output with subtracted halo has to result in meaningful output even for the minimal input.
        See https://github.com/bioimage-io/spec-bioimage-io/issues/392
        """
        if isinstance(t.shape, collections.abc.Sequence):
            return t.shape
        elif isinstance(t.shape, ParameterizedInputShape):
            return t.shape.min
        elif isinstance(t.shape, ImplicitOutputShape):
            pass
        else:
            assert_never(t.shape)

        ref_shape = cls._get_min_shape(
            tensors_by_name[t.shape.reference_tensor], tensors_by_name
        )

        if None not in t.shape.scale:
            scale: Sequence[float] = t.shape.scale  # type: ignore
        else:
            expanded_dims = [idx for idx, sc in enumerate(t.shape.scale) if sc is None]
            new_ref_shape: List[int] = []
            for idx in range(len(t.shape.scale)):
                ref_idx = idx - sum(int(exp < idx) for exp in expanded_dims)
                new_ref_shape.append(1 if idx in expanded_dims else ref_shape[ref_idx])

            ref_shape = new_ref_shape
            assert len(ref_shape) == len(t.shape.scale)
            scale = [0.0 if sc is None else sc for sc in t.shape.scale]

        offset = t.shape.offset
        assert len(offset) == len(scale)
        return [int(rs * s + 2 * off) for rs, s, off in zip(ref_shape, scale, offset)]
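
    # Illustrative arithmetic (not part of the original module): for a
    # parameterized input with min=[1, 64, 64, 1] and an implicit output with
    # scale=[1, 0.5, 0.5, 1] and offset=[0, 8, 8, 0], the minimal output is
    # int(64 * 0.5 + 2 * 8) = 48 along each spatial axis; a halo of 24 or more
    # would then fail `minimum_shape2valid_output`, since 48 - 2 * 24 < 1.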

    @model_validator(mode="after")
    def validate_tensor_references_in_inputs(self) -> Self:
        for t in self.inputs:
            for proc in t.preprocessing:
                if "reference_tensor" not in proc.kwargs:
                    continue

                ref_tensor = proc.kwargs["reference_tensor"]
                if ref_tensor is not None and str(ref_tensor) not in {
                    str(t.name) for t in self.inputs
                }:
                    raise ValueError(f"'{ref_tensor}' not found in inputs")

                if ref_tensor == t.name:
                    raise ValueError(
                        f"invalid self reference for preprocessing of tensor {t.name}"
                    )

        return self

    @model_validator(mode="after")
    def validate_tensor_references_in_outputs(self) -> Self:
        for t in self.outputs:
            for proc in t.postprocessing:
                if "reference_tensor" not in proc.kwargs:
                    continue

                ref_tensor = proc.kwargs["reference_tensor"]
                if ref_tensor is not None and str(ref_tensor) not in {
                    str(t.name) for t in self.inputs
                }:
                    raise ValueError(f"{ref_tensor} not found in inputs")

        return self

    packaged_by: List[Author] = Field(
        default_factory=cast(Callable[[], List[Author]], list)
    )
    """The persons that have packaged and uploaded this model.
    Only required if those persons differ from the `authors`."""

    parent: Optional[LinkedModel] = None
    """The model from which this model is derived, e.g. by fine-tuning the weights."""

    @field_validator("parent", mode="before")
    @classmethod
    def ignore_url_parent(cls, parent: Any):
        if isinstance(parent, dict):
            return None
        else:
            return parent

    run_mode: Optional[RunMode] = None
    """Custom run mode for this model: for more complex prediction procedures like test time
    data augmentation that currently cannot be expressed in the specification.
    No standard run modes are defined yet."""

    sample_inputs: List[FileSource_] = Field(
        default_factory=cast(Callable[[], List[FileSource_]], list)
    )
    """URLs/relative paths to sample inputs to illustrate possible inputs for the model,
    for example stored as PNG or TIFF images.
    The sample files primarily serve to inform a human user about an example use case."""

    sample_outputs: List[FileSource_] = Field(
        default_factory=cast(Callable[[], List[FileSource_]], list)
    )
    """URLs/relative paths to sample outputs corresponding to the `sample_inputs`."""

    test_inputs: NotEmpty[
        List[Annotated[FileSource_, WithSuffix(".npy", case_sensitive=True)]]
    ]
    """Test input tensors compatible with the `inputs` description for a **single test case**.
    This means if your model has more than one input, you should provide one URL/relative path for each input.
    Each test input should be a file with an ndarray in
    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
    The extension must be '.npy'."""

    test_outputs: NotEmpty[
        List[Annotated[FileSource_, WithSuffix(".npy", case_sensitive=True)]]
    ]
    """Analogous to `test_inputs`."""

    timestamp: Datetime
    """Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format
    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat)."""

    training_data: Union[LinkedDataset, DatasetDescr, None] = None
    """The dataset used to train this model"""

    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
    """The weights for this model.
    Weights can be given for different formats, but should otherwise be equivalent.
    The available weight formats determine which consumers can use this model."""

    @model_validator(mode="before")
    @classmethod
    def _convert_from_older_format(
        cls, data: BioimageioYamlContent, /
    ) -> BioimageioYamlContent:
        convert_from_older_format(data)
        return data

    def get_input_test_arrays(self) -> List[NDArray[Any]]:
        data = [load_array(ipt) for ipt in self.test_inputs]
        assert all(isinstance(d, np.ndarray) for d in data)
        return data

    def get_output_test_arrays(self) -> List[NDArray[Any]]:
        data = [load_array(out) for out in self.test_outputs]
        assert all(isinstance(d, np.ndarray) for d in data)
        return data
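

# Illustrative usage sketch (not part of the original module): since
# ModelDescr is a pydantic model, a model RDF loaded as a dict can be
# validated directly; "rdf.yaml" here is a placeholder path.
#
#     import yaml
#
#     with open("rdf.yaml") as f:
#         raw = yaml.safe_load(f)
#     model = ModelDescr.model_validate(raw)
#     print(model.name, sorted(model.weights.available_formats))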