Coverage for bioimageio/spec/model/v0_4.py: 91% (595 statements)
from __future__ import annotations

import collections.abc
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    ClassVar,
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
    cast,
)

import numpy as np
from annotated_types import Ge, Interval, MaxLen, MinLen, MultipleOf
from numpy.typing import NDArray
from pydantic import (
    AllowInfNan,
    Discriminator,
    Field,
    RootModel,
    SerializationInfo,
    SerializerFunctionWrapHandler,
    StringConstraints,
    TypeAdapter,
    ValidationInfo,
    WrapSerializer,
    field_validator,
    model_validator,
)
from typing_extensions import Annotated, Self, assert_never, get_args

from .._internal.common_nodes import (
    KwargsNode,
    Node,
    NodeWithExplicitlySetFields,
)
from .._internal.constants import SHA256_HINT
from .._internal.field_validation import validate_unique_entries
from .._internal.field_warning import issue_warning, warn
from .._internal.io import BioimageioYamlContent, WithSuffix
from .._internal.io import FileDescr as FileDescr
from .._internal.io_basics import Sha256 as Sha256
from .._internal.io_packaging import include_in_package
from .._internal.io_utils import load_array
from .._internal.packaging_context import packaging_context_var
from .._internal.types import Datetime as Datetime
from .._internal.types import FileSource_, LowerCaseIdentifier
from .._internal.types import Identifier as Identifier
from .._internal.types import LicenseId as LicenseId
from .._internal.types import NotEmpty as NotEmpty
from .._internal.url import HttpUrl as HttpUrl
from .._internal.validated_string_with_inner_node import ValidatedStringWithInnerNode
from .._internal.validator_annotations import AfterValidator, RestrictCharacters
from .._internal.version_type import Version as Version
from .._internal.warning_levels import ALERT, INFO
from ..dataset.v0_2 import VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS
from ..dataset.v0_2 import DatasetDescr as DatasetDescr
from ..dataset.v0_2 import LinkedDataset as LinkedDataset
from ..generic.v0_2 import AttachmentsDescr as AttachmentsDescr
from ..generic.v0_2 import Author as Author
from ..generic.v0_2 import BadgeDescr as BadgeDescr
from ..generic.v0_2 import CiteEntry as CiteEntry
from ..generic.v0_2 import Doi as Doi
from ..generic.v0_2 import GenericModelDescrBase
from ..generic.v0_2 import LinkedResource as LinkedResource
from ..generic.v0_2 import Maintainer as Maintainer
from ..generic.v0_2 import OrcidId as OrcidId
from ..generic.v0_2 import RelativeFilePath as RelativeFilePath
from ..generic.v0_2 import ResourceId as ResourceId
from ..generic.v0_2 import Uploader as Uploader
from ._v0_4_converter import convert_from_older_format

class ModelId(ResourceId):
    pass


AxesStr = Annotated[
    str, RestrictCharacters("bitczyx"), AfterValidator(validate_unique_entries)
]
AxesInCZYX = Annotated[
    str, RestrictCharacters("czyx"), AfterValidator(validate_unique_entries)
]

PostprocessingName = Literal[
    "binarize",
    "clip",
    "scale_linear",
    "sigmoid",
    "zero_mean_unit_variance",
    "scale_range",
    "scale_mean_variance",
]
PreprocessingName = Literal[
    "binarize",
    "clip",
    "scale_linear",
    "sigmoid",
    "zero_mean_unit_variance",
    "scale_range",
]

class TensorName(LowerCaseIdentifier):
    pass


class CallableFromDepencencyNode(Node):
    _submodule_adapter: ClassVar[TypeAdapter[Identifier]] = TypeAdapter(Identifier)

    module_name: str
    """The Python module that implements **callable_name**."""

    @field_validator("module_name", mode="after")
    def _check_submodules(cls, module_name: str) -> str:
        for submod in module_name.split("."):
            _ = cls._submodule_adapter.validate_python(submod)

        return module_name

    callable_name: Identifier
    """The callable Python identifier implemented in module **module_name**."""


class CallableFromDepencency(ValidatedStringWithInnerNode[CallableFromDepencencyNode]):
    _inner_node_class = CallableFromDepencencyNode
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            str,
            StringConstraints(strip_whitespace=True, pattern=r"^.+\..+$"),
        ]
    ]

    @classmethod
    def _get_data(cls, valid_string_data: str):
        *mods, callname = valid_string_data.split(".")
        return dict(module_name=".".join(mods), callable_name=callname)

    @property
    def module_name(self):
        """The Python module that implements **callable_name**."""
        return self._inner_node.module_name

    @property
    def callable_name(self):
        """The callable Python identifier implemented in module **module_name**."""
        return self._inner_node.callable_name

class CallableFromFileNode(Node):
    source_file: Annotated[
        Union[RelativeFilePath, HttpUrl],
        Field(union_mode="left_to_right"),
        include_in_package,
    ]
    """The Python source file that implements **callable_name**."""
    callable_name: Identifier
    """The callable Python identifier implemented in **source_file**."""


class CallableFromFile(ValidatedStringWithInnerNode[CallableFromFileNode]):
    _inner_node_class = CallableFromFileNode
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            str,
            StringConstraints(strip_whitespace=True, pattern=r"^.+:.+$"),
        ]
    ]

    @classmethod
    def _get_data(cls, valid_string_data: str):
        *file_parts, callname = valid_string_data.split(":")
        return dict(source_file=":".join(file_parts), callable_name=callname)

    @property
    def source_file(self):
        """The Python source file that implements **callable_name**."""
        return self._inner_node.source_file

    @property
    def callable_name(self):
        """The callable Python identifier implemented in **source_file**."""
        return self._inner_node.callable_name


CustomCallable = Annotated[
    Union[CallableFromFile, CallableFromDepencency], Field(union_mode="left_to_right")
]

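# Illustrative sketch (not part of the spec module): the two `CustomCallable`
# string forms parse into structured parts; the last "." resp. ":" separates
# the callable identifier. The example values are hypothetical.
#
#     from_dep = CallableFromDepencency("my_module.submodule.get_my_model")
#     assert from_dep.module_name == "my_module.submodule"
#     assert from_dep.callable_name == "get_my_model"
#
#     from_file = CallableFromFile("my_function.py:MyNetworkClass")
#     assert from_file.callable_name == "MyNetworkClass"
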
class DependenciesNode(Node):
    manager: Annotated[NotEmpty[str], Field(examples=["conda", "maven", "pip"])]
    """Dependency manager"""

    file: FileSource_
    """Dependency file"""


class Dependencies(ValidatedStringWithInnerNode[DependenciesNode]):
    _inner_node_class = DependenciesNode
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            str,
            StringConstraints(strip_whitespace=True, pattern=r"^.+:.+$"),
        ]
    ]

    @classmethod
    def _get_data(cls, valid_string_data: str):
        manager, *file_parts = valid_string_data.split(":")
        return dict(manager=manager, file=":".join(file_parts))

    @property
    def manager(self):
        """Dependency manager"""
        return self._inner_node.manager

    @property
    def file(self):
        """Dependency file"""
        return self._inner_node.file

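# Illustrative sketch: a `Dependencies` string splits on the first ":" into
# `manager` and `file` (the file name here is hypothetical).
#
#     deps = Dependencies("conda:environment.yaml")
#     assert deps.manager == "conda"
#     assert str(deps.file) == "environment.yaml"
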
WeightsFormat = Literal[
    "keras_hdf5",
    "onnx",
    "pytorch_state_dict",
    "tensorflow_js",
    "tensorflow_saved_model_bundle",
    "torchscript",
]


class WeightsEntryDescrBase(FileDescr):
    type: ClassVar[WeightsFormat]
    weights_format_name: ClassVar[str]  # human readable

    source: FileSource_
    """The weights file."""

    attachments: Annotated[
        Union[AttachmentsDescr, None],
        warn(None, "Weights entry depends on additional attachments.", ALERT),
    ] = None
    """Attachments that are specific to this weights entry."""

    authors: Union[List[Author], None] = None
    """Authors:
    Either the person(s) that have trained this model, resulting in the original weights file
    (if this is the initial weights entry, i.e. it does not have a `parent`),
    or the person(s) who have converted the weights to this weights format
    (if this is a child weights entry, i.e. it has a `parent` field).
    """

    dependencies: Annotated[
        Optional[Dependencies],
        warn(
            None,
            "Custom dependencies ({value}) specified. Avoid this whenever possible "
            + "to allow execution in a wider range of software environments.",
        ),
        Field(
            examples=[
                "conda:environment.yaml",
                "maven:./pom.xml",
                "pip:./requirements.txt",
            ]
        ),
    ] = None
    """Dependency manager and dependency file, specified as `<dependency manager>:<relative file path>`."""

    parent: Annotated[
        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
    ] = None
    """The source weights these weights were converted from.
    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
    the `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
    All weight entries except one (the initial set of weights resulting from training the model)
    need to have this field."""

    @model_validator(mode="after")
    def check_parent_is_not_self(self) -> Self:
        if self.type == self.parent:
            raise ValueError("Weights entry can't be its own parent.")

        return self

class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
    type = "keras_hdf5"
    weights_format_name: ClassVar[str] = "Keras HDF5"
    tensorflow_version: Optional[Version] = None
    """TensorFlow version used to create these weights"""

    @field_validator("tensorflow_version", mode="after")
    @classmethod
    def _tfv(cls, value: Any):
        if value is None:
            issue_warning(
                "missing. Please specify the TensorFlow version"
                + " these weights were created with.",
                value=value,
                severity=ALERT,
                field="tensorflow_version",
            )
        return value


class OnnxWeightsDescr(WeightsEntryDescrBase):
    type = "onnx"
    weights_format_name: ClassVar[str] = "ONNX"
    opset_version: Optional[Annotated[int, Ge(7)]] = None
    """ONNX opset version"""

    @field_validator("opset_version", mode="after")
    @classmethod
    def _ov(cls, value: Any):
        if value is None:
            issue_warning(
                "Missing ONNX opset version (aka ONNX opset number). "
                + "Please specify the ONNX opset version these weights were created"
                + " with.",
                value=value,
                severity=ALERT,
                field="opset_version",
            )
        return value


class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
    type = "pytorch_state_dict"
    weights_format_name: ClassVar[str] = "Pytorch State Dict"
    architecture: CustomCallable = Field(
        examples=["my_function.py:MyNetworkClass", "my_module.submodule.get_my_model"]
    )
    """Callable returning a torch.nn.Module instance.
    Local implementation: `<relative path to file>:<identifier of implementation within the file>`.
    Implementation in a dependency: `<dependency-package>.<[dependency-module]>.<identifier>`."""

    architecture_sha256: Annotated[
        Optional[Sha256],
        Field(
            description=(
                "The SHA256 of the architecture source file, if the architecture is not"
                " defined in a module listed in `dependencies`\n"
            )
            + SHA256_HINT,
        ),
    ] = None
    """The SHA256 of the architecture source file,
    if the architecture is not defined in a module listed in `dependencies`"""

    @model_validator(mode="after")
    def check_architecture_sha256(self) -> Self:
        if isinstance(self.architecture, CallableFromFile):
            if self.architecture_sha256 is None:
                raise ValueError(
                    "Missing required `architecture_sha256` for `architecture` with"
                    + " source file."
                )
        elif self.architecture_sha256 is not None:
            raise ValueError(
                "Got `architecture_sha256` for architecture that does not have a source"
                + " file."
            )

        return self

    kwargs: Dict[str, Any] = Field(
        default_factory=cast(Callable[[], Dict[str, Any]], dict)
    )
    """Keyword arguments for the `architecture` callable"""

    pytorch_version: Optional[Version] = None
    """Version of the PyTorch library used.
    If `dependencies` is specified it should include pytorch and the version has to match.
    (`dependencies` overrules `pytorch_version`)"""

    @field_validator("pytorch_version", mode="after")
    @classmethod
    def _ptv(cls, value: Any):
        if value is None:
            issue_warning(
                "missing. Please specify the PyTorch version these"
                + " PyTorch state dict weights were created with.",
                value=value,
                severity=ALERT,
                field="pytorch_version",
            )
        return value


class TorchscriptWeightsDescr(WeightsEntryDescrBase):
    type = "torchscript"
    weights_format_name: ClassVar[str] = "TorchScript"
    pytorch_version: Optional[Version] = None
    """Version of the PyTorch library used."""

    @field_validator("pytorch_version", mode="after")
    @classmethod
    def _ptv(cls, value: Any):
        if value is None:
            issue_warning(
                "missing. Please specify the PyTorch version these"
                + " Torchscript weights were created with.",
                value=value,
                severity=ALERT,
                field="pytorch_version",
            )
        return value


class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
    type = "tensorflow_js"
    weights_format_name: ClassVar[str] = "Tensorflow.js"
    tensorflow_version: Optional[Version] = None
    """Version of the TensorFlow library used."""

    @field_validator("tensorflow_version", mode="after")
    @classmethod
    def _tfv(cls, value: Any):
        if value is None:
            issue_warning(
                "missing. Please specify the TensorFlow version"
                + " these TensorflowJs weights were created with.",
                value=value,
                severity=ALERT,
                field="tensorflow_version",
            )
        return value

    source: FileSource_
    """The multi-file weights.
    All required files/folders should be included in a zip archive."""


class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
    type = "tensorflow_saved_model_bundle"
    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
    tensorflow_version: Optional[Version] = None
    """Version of the TensorFlow library used."""

    @field_validator("tensorflow_version", mode="after")
    @classmethod
    def _tfv(cls, value: Any):
        if value is None:
            issue_warning(
                "missing. Please specify the TensorFlow version"
                + " these Tensorflow saved model bundle weights were created with.",
                value=value,
                severity=ALERT,
                field="tensorflow_version",
            )
        return value

class WeightsDescr(Node):
    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
    onnx: Optional[OnnxWeightsDescr] = None
    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
        None
    )
    torchscript: Optional[TorchscriptWeightsDescr] = None

    @model_validator(mode="after")
    def check_one_entry(self) -> Self:
        if all(
            entry is None
            for entry in [
                self.keras_hdf5,
                self.onnx,
                self.pytorch_state_dict,
                self.tensorflow_js,
                self.tensorflow_saved_model_bundle,
                self.torchscript,
            ]
        ):
            raise ValueError("Missing weights entry")

        return self

    def __getitem__(
        self,
        key: WeightsFormat,
    ):
        if key == "keras_hdf5":
            ret = self.keras_hdf5
        elif key == "onnx":
            ret = self.onnx
        elif key == "pytorch_state_dict":
            ret = self.pytorch_state_dict
        elif key == "tensorflow_js":
            ret = self.tensorflow_js
        elif key == "tensorflow_saved_model_bundle":
            ret = self.tensorflow_saved_model_bundle
        elif key == "torchscript":
            ret = self.torchscript
        else:
            raise KeyError(key)

        if ret is None:
            raise KeyError(key)

        return ret

    @property
    def available_formats(self):
        return {
            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
            **({} if self.onnx is None else {"onnx": self.onnx}),
            **(
                {}
                if self.pytorch_state_dict is None
                else {"pytorch_state_dict": self.pytorch_state_dict}
            ),
            **(
                {}
                if self.tensorflow_js is None
                else {"tensorflow_js": self.tensorflow_js}
            ),
            **(
                {}
                if self.tensorflow_saved_model_bundle is None
                else {
                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
                }
            ),
            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
        }

    @property
    def missing_formats(self):
        return {
            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
        }

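# Illustrative sketch: `WeightsDescr` behaves like a sparse mapping from
# weights format name to entry (`weights` is a hypothetical, already
# validated instance).
#
#     entry = weights["torchscript"]       # KeyError if format is unavailable
#     present = weights.available_formats  # dict: format name -> entry
#     absent = weights.missing_formats     # set of all other format names
#     assert set(present).isdisjoint(absent)
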
class ParameterizedInputShape(Node):
    """A sequence of valid shapes given by `shape_k = min + k * step for k in {0, 1, ...}`."""

    min: NotEmpty[List[int]]
    """The minimum input shape"""

    step: NotEmpty[List[int]]
    """The minimum shape change"""

    def __len__(self) -> int:
        return len(self.min)

    @model_validator(mode="after")
    def matching_lengths(self) -> Self:
        if len(self.min) != len(self.step):
            raise ValueError("`min` and `step` required to have the same length")

        return self

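# Illustrative sketch of the shape family described above: every valid shape
# is `min + k * step` for some non-negative integer k (values are made up).
#
#     shape = ParameterizedInputShape(min=[1, 64, 64, 1], step=[0, 32, 32, 0])
#     k = 2
#     assert [m + k * s for m, s in zip(shape.min, shape.step)] == [1, 128, 128, 1]
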
class ImplicitOutputShape(Node):
    """Output tensor shape depending on an input tensor shape.
    `shape(output_tensor) = shape(input_tensor) * scale + 2 * offset`"""

    reference_tensor: TensorName
    """Name of the reference tensor."""

    scale: NotEmpty[List[Optional[float]]]
    """output_pix/input_pix for each dimension.
    'null' values indicate new dimensions, whose length is defined by 2*`offset`"""

    offset: NotEmpty[List[Union[int, Annotated[float, MultipleOf(0.5)]]]]
    """Position of origin with respect to input."""

    def __len__(self) -> int:
        return len(self.scale)

    @model_validator(mode="after")
    def matching_lengths(self) -> Self:
        if len(self.scale) != len(self.offset):
            raise ValueError(
                f"scale {self.scale} has to have same length as offset {self.offset}!"
            )
        # if we have an expanded dimension, make sure that its offset is not zero
        for sc, off in zip(self.scale, self.offset):
            if sc is None and not off:
                raise ValueError("`offset` must not be zero if `scale` is null")

        return self

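# Illustrative sketch of the documented relation
# `shape(output) = shape(input) * scale + 2 * offset` (values are made up):
#
#     out_shape = ImplicitOutputShape(
#         reference_tensor="raw", scale=[1.0, 1.0, 1.0], offset=[0, 8, 8]
#     )
#     input_shape = [1, 256, 256]
#     assert [
#         int(i * sc + 2 * off)
#         for i, sc, off in zip(input_shape, out_shape.scale, out_shape.offset)
#     ] == [1, 272, 272]
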
class TensorDescrBase(Node):
    name: TensorName
    """Tensor name. No duplicates are allowed."""

    description: str = ""

    axes: AxesStr
    """Axes identifying characters. Same length and order as the axes in `shape`.
    | axis | description |
    | --- | --- |
    | b | batch (groups multiple samples) |
    | i | instance/index/element |
    | t | time |
    | c | channel |
    | z | spatial dimension z |
    | y | spatial dimension y |
    | x | spatial dimension x |
    """

    data_range: Optional[
        Tuple[Annotated[float, AllowInfNan(True)], Annotated[float, AllowInfNan(True)]]
    ] = None
    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
    If not specified, the full data range that can be expressed in `data_type` is allowed."""


class ProcessingKwargs(KwargsNode):
    """base class for pre-/postprocessing key word arguments"""


class ProcessingDescrBase(NodeWithExplicitlySetFields):
    """processing base class"""

class BinarizeKwargs(ProcessingKwargs):
    """key word arguments for `BinarizeDescr`"""

    threshold: float
    """The fixed threshold"""


class BinarizeDescr(ProcessingDescrBase):
    """Binarize the tensor with a fixed `BinarizeKwargs.threshold`.
    Values above the threshold will be set to one, values below the threshold to zero.
    """

    implemented_name: ClassVar[Literal["binarize"]] = "binarize"
    if TYPE_CHECKING:
        name: Literal["binarize"] = "binarize"
    else:
        name: Literal["binarize"]

    kwargs: BinarizeKwargs

class ClipKwargs(ProcessingKwargs):
    """key word arguments for `ClipDescr`"""

    min: float
    """minimum value for clipping"""
    max: float
    """maximum value for clipping"""


class ClipDescr(ProcessingDescrBase):
    """Clip tensor values to a range.

    Set tensor values below `ClipKwargs.min` to `ClipKwargs.min`
    and above `ClipKwargs.max` to `ClipKwargs.max`.
    """

    implemented_name: ClassVar[Literal["clip"]] = "clip"
    if TYPE_CHECKING:
        name: Literal["clip"] = "clip"
    else:
        name: Literal["clip"]

    kwargs: ClipKwargs

class ScaleLinearKwargs(ProcessingKwargs):
    """key word arguments for `ScaleLinearDescr`"""

    axes: Annotated[Optional[AxesInCZYX], Field(examples=["xy"])] = None
    """The subset of axes to scale jointly.
    For example xy to scale the two image axes for 2d data jointly."""

    gain: Union[float, List[float]] = 1.0
    """multiplicative factor"""

    offset: Union[float, List[float]] = 0.0
    """additive term"""

    @model_validator(mode="after")
    def either_gain_or_offset(self) -> Self:
        if (
            self.gain == 1.0
            or isinstance(self.gain, list)
            and all(g == 1.0 for g in self.gain)
        ) and (
            self.offset == 0.0
            or isinstance(self.offset, list)
            and all(off == 0.0 for off in self.offset)
        ):
            raise ValueError(
                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or"
                + " `offset` != 0.0."
            )

        return self

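# Minimal numpy sketch of the transformation these kwargs describe,
# `out = tensor * gain + offset` (values are illustrative):
#
#     import numpy as np
#
#     kwargs = ScaleLinearKwargs(axes="xy", gain=2.0, offset=1.0)
#     tensor = np.array([[0.0, 0.5], [1.0, 1.5]])
#     out = tensor * kwargs.gain + kwargs.offset  # [[1.0, 2.0], [3.0, 4.0]]
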
class ScaleLinearDescr(ProcessingDescrBase):
    """Fixed linear scaling."""

    implemented_name: ClassVar[Literal["scale_linear"]] = "scale_linear"
    if TYPE_CHECKING:
        name: Literal["scale_linear"] = "scale_linear"
    else:
        name: Literal["scale_linear"]

    kwargs: ScaleLinearKwargs

class SigmoidDescr(ProcessingDescrBase):
    """The logistic sigmoid function, a.k.a. expit function."""

    implemented_name: ClassVar[Literal["sigmoid"]] = "sigmoid"
    if TYPE_CHECKING:
        name: Literal["sigmoid"] = "sigmoid"
    else:
        name: Literal["sigmoid"]

    @property
    def kwargs(self) -> ProcessingKwargs:
        """empty kwargs"""
        return ProcessingKwargs()

class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
    """key word arguments for `ZeroMeanUnitVarianceDescr`"""

    mode: Literal["fixed", "per_dataset", "per_sample"] = "fixed"
    """Mode for computing mean and variance.
    | mode | description |
    | ----------- | ------------------------------------ |
    | fixed | Fixed values for mean and variance |
    | per_dataset | Compute for the entire dataset |
    | per_sample | Compute for each sample individually |
    """
    axes: Annotated[AxesInCZYX, Field(examples=["xy"])]
    """The subset of axes to normalize jointly.
    For example `xy` to normalize the two image axes for 2d data jointly."""

    mean: Annotated[
        Union[float, NotEmpty[List[float]], None], Field(examples=[(1.1, 2.2, 3.3)])
    ] = None
    """The mean value(s) to use for `mode: fixed`.
    For example `[1.1, 2.2, 3.3]` in the case of a 3 channel image with `axes: xy`."""
    # todo: check if means match input axes (for mode 'fixed')

    std: Annotated[
        Union[float, NotEmpty[List[float]], None], Field(examples=[(0.1, 0.2, 0.3)])
    ] = None
    """The standard deviation values to use for `mode: fixed`. Analogous to mean."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""

    @model_validator(mode="after")
    def mean_and_std_match_mode(self) -> Self:
        if self.mode == "fixed" and (self.mean is None or self.std is None):
            raise ValueError("`mean` and `std` are required for `mode: fixed`.")
        elif self.mode != "fixed" and (self.mean is not None or self.std is not None):
            raise ValueError(f"`mean` and `std` not allowed for `mode: {self.mode}`")

        return self

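# Minimal numpy sketch of the `mode: fixed` formula documented above,
# `out = (tensor - mean) / (std + eps)` (values are illustrative):
#
#     import numpy as np
#
#     kwargs = ZeroMeanUnitVarianceKwargs(mode="fixed", axes="xy", mean=2.0, std=4.0)
#     tensor = np.array([2.0, 6.0, 10.0])
#     out = (tensor - kwargs.mean) / (kwargs.std + kwargs.eps)  # ~[0.0, 1.0, 2.0]
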
class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
    """Subtract mean and divide by standard deviation."""

    implemented_name: ClassVar[Literal["zero_mean_unit_variance"]] = (
        "zero_mean_unit_variance"
    )
    if TYPE_CHECKING:
        name: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
    else:
        name: Literal["zero_mean_unit_variance"]

    kwargs: ZeroMeanUnitVarianceKwargs

class ScaleRangeKwargs(ProcessingKwargs):
    """key word arguments for `ScaleRangeDescr`

    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
    this processing step normalizes data to the [0, 1] interval.
    For other percentiles the normalized values will partially be outside the [0, 1]
    interval. Use `ScaleRange` followed by `ClipDescr` if you want to limit the
    normalized values to a range.
    """

    mode: Literal["per_dataset", "per_sample"]
    """Mode for computing percentiles.
    | mode | description |
    | ----------- | ------------------------------------ |
    | per_dataset | compute for the entire dataset |
    | per_sample | compute for each sample individually |
    """
    axes: Annotated[AxesInCZYX, Field(examples=["xy"])]
    """The subset of axes to normalize jointly.
    For example xy to normalize the two image axes for 2d data jointly."""

    min_percentile: Annotated[Union[int, float], Interval(ge=0, lt=100)] = 0.0
    """The lower percentile used to determine the value to align with zero."""

    max_percentile: Annotated[Union[int, float], Interval(gt=1, le=100)] = 100.0
    """The upper percentile used to determine the value to align with one.
    Has to be bigger than `min_percentile`.
    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
    accepting percentiles specified in the range 0.0 to 1.0."""

    @model_validator(mode="after")
    def min_smaller_max(self, info: ValidationInfo) -> Self:
        if self.min_percentile >= self.max_percentile:
            raise ValueError(
                f"min_percentile {self.min_percentile} >= max_percentile"
                + f" {self.max_percentile}"
            )

        return self

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability.
    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
    with `v_lower,v_upper` values at the respective percentiles."""

    reference_tensor: Optional[TensorName] = None
    """Tensor name to compute the percentiles from. Default: The tensor itself.
    For any tensor in `inputs` only input tensor references are allowed.
    For a tensor in `outputs` only input tensor references are allowed if `mode: per_dataset`"""

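# Minimal numpy sketch of the percentile normalization documented above,
# `out = (tensor - v_lower) / (v_upper - v_lower + eps)` (values illustrative):
#
#     import numpy as np
#
#     kwargs = ScaleRangeKwargs(
#         mode="per_sample", axes="xy", min_percentile=1.0, max_percentile=99.0
#     )
#     tensor = np.random.rand(64, 64)
#     v_lower = np.percentile(tensor, kwargs.min_percentile)
#     v_upper = np.percentile(tensor, kwargs.max_percentile)
#     out = (tensor - v_lower) / (v_upper - v_lower + kwargs.eps)
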
class ScaleRangeDescr(ProcessingDescrBase):
    """Scale with percentiles."""

    implemented_name: ClassVar[Literal["scale_range"]] = "scale_range"
    if TYPE_CHECKING:
        name: Literal["scale_range"] = "scale_range"
    else:
        name: Literal["scale_range"]

    kwargs: ScaleRangeKwargs

class ScaleMeanVarianceKwargs(ProcessingKwargs):
    """key word arguments for `ScaleMeanVarianceDescr`"""

    mode: Literal["per_dataset", "per_sample"]
    """Mode for computing mean and variance.
    | mode | description |
    | ----------- | ------------------------------------ |
    | per_dataset | Compute for the entire dataset |
    | per_sample | Compute for each sample individually |
    """

    reference_tensor: TensorName
    """Name of tensor to match."""

    axes: Annotated[Optional[AxesInCZYX], Field(examples=["xy"])] = None
    """The subset of axes to scale jointly.
    For example xy to normalize the two image axes for 2d data jointly.
    Default: scale all non-batch axes jointly."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability:
    `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean`."""

class ScaleMeanVarianceDescr(ProcessingDescrBase):
    """Scale the tensor s.t. its mean and variance match a reference tensor."""

    implemented_name: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
    if TYPE_CHECKING:
        name: Literal["scale_mean_variance"] = "scale_mean_variance"
    else:
        name: Literal["scale_mean_variance"]

    kwargs: ScaleMeanVarianceKwargs


PreprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        ScaleLinearDescr,
        SigmoidDescr,
        ZeroMeanUnitVarianceDescr,
        ScaleRangeDescr,
    ],
    Discriminator("name"),
]
PostprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        ScaleLinearDescr,
        SigmoidDescr,
        ZeroMeanUnitVarianceDescr,
        ScaleRangeDescr,
        ScaleMeanVarianceDescr,
    ],
    Discriminator("name"),
]

class InputTensorDescr(TensorDescrBase):
    data_type: Literal["float32", "uint8", "uint16"]
    """For now an input tensor is expected to be given as `float32`.
    The data flow in bioimage.io models is explained
    [in this diagram.](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit)."""

    shape: Annotated[
        Union[Sequence[int], ParameterizedInputShape],
        Field(
            examples=[(1, 512, 512, 1), dict(min=(1, 64, 64, 1), step=(0, 32, 32, 0))]
        ),
    ]
    """Specification of input tensor shape."""

    preprocessing: List[PreprocessingDescr] = Field(
        default_factory=cast(  # TODO: (py>3.8) use list[PreprocessingDescr]
            Callable[[], List[PreprocessingDescr]], list
        )
    )
    """Description of how this input should be preprocessed."""

    @model_validator(mode="after")
    def zero_batch_step_and_one_batch_size(self) -> Self:
        bidx = self.axes.find("b")
        if bidx == -1:
            return self

        if isinstance(self.shape, ParameterizedInputShape):
            step = self.shape.step
            shape = self.shape.min
            if step[bidx] != 0:
                raise ValueError(
                    "Input shape step has to be zero in the batch dimension (the batch"
                    + " dimension can always be increased, but `step` should specify"
                    + " how to increase the minimal shape to find the largest single"
                    + " batch shape)"
                )
        else:
            shape = self.shape

        if shape[bidx] != 1:
            raise ValueError("Input shape has to be 1 in the batch dimension b.")

        return self

    @model_validator(mode="after")
    def validate_preprocessing_kwargs(self) -> Self:
        for p in self.preprocessing:
            kwargs_axes = p.kwargs.get("axes")
            if isinstance(kwargs_axes, str) and any(
                a not in self.axes for a in kwargs_axes
            ):
                raise ValueError("`kwargs.axes` needs to be subset of `axes`")

        return self

class OutputTensorDescr(TensorDescrBase):
    data_type: Literal[
        "float32",
        "float64",
        "uint8",
        "int8",
        "uint16",
        "int16",
        "uint32",
        "int32",
        "uint64",
        "int64",
        "bool",
    ]
    """Data type.
    The data flow in bioimage.io models is explained
    [in this diagram.](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit)."""

    shape: Union[Sequence[int], ImplicitOutputShape]
    """Output tensor shape."""

    halo: Optional[Sequence[int]] = None
    """The `halo` that should be cropped from the output tensor to avoid boundary effects.
    The `halo` is to be cropped from both sides, i.e. `shape_after_crop = shape - 2 * halo`.
    To document a `halo` that is already cropped by the model `shape.offset` has to be used instead."""

    postprocessing: List[PostprocessingDescr] = Field(
        default_factory=cast(Callable[[], List[PostprocessingDescr]], list)
    )
    """Description of how this output should be postprocessed."""

    @model_validator(mode="after")
    def matching_halo_length(self) -> Self:
        if self.halo and len(self.halo) != len(self.shape):
            raise ValueError(
                f"halo {self.halo} has to have same length as shape {self.shape}!"
            )

        return self

    @model_validator(mode="after")
    def validate_postprocessing_kwargs(self) -> Self:
        for p in self.postprocessing:
            kwargs_axes = p.kwargs.get("axes", "")
            if not isinstance(kwargs_axes, str):
                raise ValueError(f"Expected {kwargs_axes} to be a string")

            if any(a not in self.axes for a in kwargs_axes):
                raise ValueError("`kwargs.axes` needs to be subset of axes")

        return self

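# Illustrative sketch of the halo cropping rule documented above,
# `shape_after_crop = shape - 2 * halo` (values are made up):
#
#     shape = [1, 512, 512, 1]
#     halo = [0, 32, 32, 0]
#     assert [s - 2 * h for s, h in zip(shape, halo)] == [1, 448, 448, 1]
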
KnownRunMode = Literal["deepimagej"]


class RunMode(Node):
    name: Annotated[
        Union[KnownRunMode, str], warn(KnownRunMode, "Unknown run mode '{value}'.")
    ]
    """Run mode name"""

    kwargs: Dict[str, Any] = Field(
        default_factory=cast(Callable[[], Dict[str, Any]], dict)
    )
    """Run mode specific key word arguments"""


class LinkedModel(Node):
    """Reference to a bioimage.io model."""

    id: Annotated[ModelId, Field(examples=["affable-shark", "ambitious-sloth"])]
    """A valid model `id` from the bioimage.io collection."""

    version_number: Optional[int] = None
    """version number (n-th published version, not the semantic version) of linked model"""

def package_weights(
    value: Node,  # Union[v0_4.WeightsDescr, v0_5.WeightsDescr]
    handler: SerializerFunctionWrapHandler,
    info: SerializationInfo,
):
    ctxt = packaging_context_var.get()
    if ctxt is not None and ctxt.weights_priority_order is not None:
        for wf in ctxt.weights_priority_order:
            w = getattr(value, wf, None)
            if w is not None:
                break
        else:
            raise ValueError(
                "None of the weight formats in `weights_priority_order`"
                + f" ({ctxt.weights_priority_order}) is present in the given model."
            )

        assert isinstance(w, Node), type(w)
        # construct WeightsDescr with new single weight format entry
        new_w = w.model_construct(**{k: v for k, v in w if k != "parent"})
        value = value.model_construct(None, **{wf: new_w})

    return handler(
        value,
        info,  # pyright: ignore[reportArgumentType]  # taken from pydantic docs
    )

class ModelDescr(GenericModelDescrBase):
    """Specification of the fields used in a bioimage.io-compliant RDF that describes AI models with pretrained weights.

    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
    """

    implemented_format_version: ClassVar[Literal["0.4.10"]] = "0.4.10"
    if TYPE_CHECKING:
        format_version: Literal["0.4.10"] = "0.4.10"
    else:
        format_version: Literal["0.4.10"]
    """Version of the bioimage.io model description specification used.
    When creating a new model always use the latest micro/patch version described here.
    The `format_version` is important for any consumer software to understand how to parse the fields.
    """

    implemented_type: ClassVar[Literal["model"]] = "model"
    if TYPE_CHECKING:
        type: Literal["model"] = "model"
    else:
        type: Literal["model"]
    """Specialized resource type 'model'"""

    id: Optional[ModelId] = None
    """bioimage.io-wide unique resource identifier
    assigned by bioimage.io; version **un**specific."""

    authors: NotEmpty[  # pyright: ignore[reportGeneralTypeIssues]  # make mandatory
        List[Author]
    ]
    """The authors are the creators of the model RDF and the primary points of contact."""

    documentation: Annotated[
        FileSource_,
        Field(
            examples=[
                "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md",
                "README.md",
            ],
        ),
    ]
    """URL or relative path to a markdown file with additional documentation.
    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
    The documentation should include a '[#[#]]# Validation' (sub)section
    with details on how to quantitatively validate the model on unseen data."""

    inputs: NotEmpty[List[InputTensorDescr]]
    """Describes the input tensors expected by this model."""

    license: Annotated[
        Union[LicenseId, str],
        warn(LicenseId, "Unknown license id '{value}'."),
        Field(examples=["CC0-1.0", "MIT", "BSD-2-Clause"]),
    ]
    """A [SPDX license identifier](https://spdx.org/licenses/).
    We do not support custom licenses beyond the SPDX license list; if you need that please
    [open a GitHub issue](https://github.com/bioimage-io/spec-bioimage-io/issues/new/choose
    ) to discuss your intentions with the community."""

    name: Annotated[
        str,
        MinLen(1),
        warn(MinLen(5), "Name shorter than 5 characters.", INFO),
        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
    ]
    """A human-readable name of this model.
    It should be no longer than 64 characters and only contain letter, number, underscore, minus or space characters."""

    outputs: NotEmpty[List[OutputTensorDescr]]
    """Describes the output tensors."""

    @field_validator("inputs", "outputs")
    @classmethod
    def unique_tensor_descr_names(
        cls, value: Sequence[Union[InputTensorDescr, OutputTensorDescr]]
    ) -> Sequence[Union[InputTensorDescr, OutputTensorDescr]]:
        unique_names = {str(v.name) for v in value}
        if len(unique_names) != len(value):
            raise ValueError("Duplicate tensor descriptor names")

        return value

    @model_validator(mode="after")
    def unique_io_names(self) -> Self:
        unique_names = {str(ss.name) for s in (self.inputs, self.outputs) for ss in s}
        if len(unique_names) != (len(self.inputs) + len(self.outputs)):
            raise ValueError("Duplicate tensor descriptor names across inputs/outputs")

        return self

    @model_validator(mode="after")
    def minimum_shape2valid_output(self) -> Self:
        tensors_by_name: Dict[
            TensorName, Union[InputTensorDescr, OutputTensorDescr]
        ] = {t.name: t for t in self.inputs + self.outputs}

        for out in self.outputs:
            if isinstance(out.shape, ImplicitOutputShape):
                ndim_ref = len(tensors_by_name[out.shape.reference_tensor].shape)
                ndim_out_ref = len(
                    [scale for scale in out.shape.scale if scale is not None]
                )
                if ndim_ref != ndim_out_ref:
                    expanded_dim_note = (
                        " Note that expanded dimensions (`scale`: null) are not"
                        + f" counted for {out.name}'s dimensionality here."
                        if None in out.shape.scale
                        else ""
                    )
                    raise ValueError(
                        f"Referenced tensor '{out.shape.reference_tensor}' with"
                        + f" {ndim_ref} dimensions does not match output tensor"
                        + f" '{out.name}' with"
                        + f" {ndim_out_ref} dimensions.{expanded_dim_note}"
                    )

            min_out_shape = self._get_min_shape(out, tensors_by_name)
            if out.halo:
                halo = out.halo
                halo_msg = f" for halo {out.halo}"
            else:
                halo = [0] * len(min_out_shape)
                halo_msg = ""

            if any([s - 2 * h < 1 for s, h in zip(min_out_shape, halo)]):
                raise ValueError(
                    f"Minimal shape {min_out_shape} of output {out.name} is too"
                    + f" small{halo_msg}."
                )

        return self

    @classmethod
    def _get_min_shape(
        cls,
        t: Union[InputTensorDescr, OutputTensorDescr],
        tensors_by_name: Dict[TensorName, Union[InputTensorDescr, OutputTensorDescr]],
    ) -> Sequence[int]:
        """output with subtracted halo has to result in meaningful output even for the minimal input
        see https://github.com/bioimage-io/spec-bioimage-io/issues/392
        """
        if isinstance(t.shape, collections.abc.Sequence):
            return t.shape
        elif isinstance(t.shape, ParameterizedInputShape):
            return t.shape.min
        elif isinstance(t.shape, ImplicitOutputShape):
            pass
        else:
            assert_never(t.shape)

        ref_shape = cls._get_min_shape(
            tensors_by_name[t.shape.reference_tensor], tensors_by_name
        )

        if None not in t.shape.scale:
            scale: Sequence[float] = t.shape.scale  # type: ignore
        else:
            expanded_dims = [idx for idx, sc in enumerate(t.shape.scale) if sc is None]
            new_ref_shape: List[int] = []
            for idx in range(len(t.shape.scale)):
                ref_idx = idx - sum(int(exp < idx) for exp in expanded_dims)
                new_ref_shape.append(1 if idx in expanded_dims else ref_shape[ref_idx])

            ref_shape = new_ref_shape
            assert len(ref_shape) == len(t.shape.scale)
            scale = [0.0 if sc is None else sc for sc in t.shape.scale]

        offset = t.shape.offset
        assert len(offset) == len(scale)
        return [int(rs * s + 2 * off) for rs, s, off in zip(ref_shape, scale, offset)]

    @model_validator(mode="after")
    def validate_tensor_references_in_inputs(self) -> Self:
        for t in self.inputs:
            for proc in t.preprocessing:
                if "reference_tensor" not in proc.kwargs:
                    continue

                ref_tensor = proc.kwargs["reference_tensor"]
                if ref_tensor is not None and str(ref_tensor) not in {
                    str(t.name) for t in self.inputs
                }:
                    raise ValueError(f"'{ref_tensor}' not found in inputs")

                if ref_tensor == t.name:
                    raise ValueError(
                        f"invalid self reference for preprocessing of tensor {t.name}"
                    )

        return self

    @model_validator(mode="after")
    def validate_tensor_references_in_outputs(self) -> Self:
        for t in self.outputs:
            for proc in t.postprocessing:
                if "reference_tensor" not in proc.kwargs:
                    continue

                ref_tensor = proc.kwargs["reference_tensor"]
                if ref_tensor is not None and str(ref_tensor) not in {
                    str(t.name) for t in self.inputs
                }:
                    raise ValueError(f"{ref_tensor} not found in inputs")

        return self

    packaged_by: List[Author] = Field(
        default_factory=cast(Callable[[], List[Author]], list)
    )
    """The persons that have packaged and uploaded this model.
    Only required if those persons differ from the `authors`."""

    parent: Optional[LinkedModel] = None
    """The model from which this model is derived, e.g. by fine-tuning the weights."""

    @field_validator("parent", mode="before")
    @classmethod
    def ignore_url_parent(cls, parent: Any):
        if isinstance(parent, dict):
            return None
        else:
            return parent

    run_mode: Optional[RunMode] = None
    """Custom run mode for this model: for more complex prediction procedures like test time
    data augmentation that currently cannot be expressed in the specification.
    No standard run modes are defined yet."""

    sample_inputs: List[FileSource_] = Field(
        default_factory=cast(Callable[[], List[FileSource_]], list)
    )
    """URLs/relative paths to sample inputs to illustrate possible inputs for the model,
    for example stored as PNG or TIFF images.
    The sample files primarily serve to inform a human user about an example use case."""

    sample_outputs: List[FileSource_] = Field(
        default_factory=cast(Callable[[], List[FileSource_]], list)
    )
    """URLs/relative paths to sample outputs corresponding to the `sample_inputs`."""

    test_inputs: NotEmpty[
        List[Annotated[FileSource_, WithSuffix(".npy", case_sensitive=True)]]
    ]
    """Test input tensors compatible with the `inputs` description for a **single test case**.
    This means if your model has more than one input, you should provide one URL/relative path for each input.
    Each test input should be a file with an ndarray in
    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
    The extension must be '.npy'."""

    test_outputs: NotEmpty[
        List[Annotated[FileSource_, WithSuffix(".npy", case_sensitive=True)]]
    ]
    """Analogous to `test_inputs`."""

    timestamp: Datetime
    """Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format
    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat)."""

    training_data: Union[LinkedDataset, DatasetDescr, None] = None
    """The dataset used to train this model"""

    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
    """The weights for this model.
    Weights can be given for different formats, but should otherwise be equivalent.
    The available weight formats determine which consumers can use this model."""

    @model_validator(mode="before")
    @classmethod
    def _convert_from_older_format(
        cls, data: BioimageioYamlContent, /
    ) -> BioimageioYamlContent:
        convert_from_older_format(data)
        return data

    def get_input_test_arrays(self) -> List[NDArray[Any]]:
        data = [load_array(ipt) for ipt in self.test_inputs]
        assert all(isinstance(d, np.ndarray) for d in data)
        return data

    def get_output_test_arrays(self) -> List[NDArray[Any]]:
        data = [load_array(out) for out in self.test_outputs]
        assert all(isinstance(d, np.ndarray) for d in data)
        return data