from __future__ import annotations

import collections.abc
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    ClassVar,
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
    cast,
)

import numpy as np
from annotated_types import Ge, Interval, MaxLen, MinLen, MultipleOf
from numpy.typing import NDArray
from pydantic import (
    AllowInfNan,
    Discriminator,
    Field,
    RootModel,
    SerializationInfo,
    SerializerFunctionWrapHandler,
    StringConstraints,
    TypeAdapter,
    ValidationInfo,
    WrapSerializer,
    field_validator,
    model_validator,
)
from typing_extensions import Annotated, Self, assert_never, get_args

from .._internal.common_nodes import (
    KwargsNode,
    Node,
    NodeWithExplicitlySetFields,
)
from .._internal.constants import SHA256_HINT
from .._internal.field_validation import validate_unique_entries
from .._internal.field_warning import issue_warning, warn
from .._internal.io import BioimageioYamlContent, WithSuffix
from .._internal.io import FileDescr as FileDescr
from .._internal.io_basics import Sha256 as Sha256
from .._internal.io_packaging import include_in_package
from .._internal.io_utils import load_array
from .._internal.packaging_context import packaging_context_var
from .._internal.types import Datetime as Datetime
from .._internal.types import FileSource_, LowerCaseIdentifier
from .._internal.types import Identifier as Identifier
from .._internal.types import LicenseId as LicenseId
from .._internal.types import NotEmpty as NotEmpty
from .._internal.url import HttpUrl as HttpUrl
from .._internal.validated_string_with_inner_node import ValidatedStringWithInnerNode
from .._internal.validator_annotations import AfterValidator, RestrictCharacters
from .._internal.version_type import Version as Version
from .._internal.warning_levels import ALERT, INFO
from ..dataset.v0_2 import VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS
from ..dataset.v0_2 import DatasetDescr as DatasetDescr
from ..dataset.v0_2 import LinkedDataset as LinkedDataset
from ..generic.v0_2 import AttachmentsDescr as AttachmentsDescr
from ..generic.v0_2 import Author as Author
from ..generic.v0_2 import BadgeDescr as BadgeDescr
from ..generic.v0_2 import CiteEntry as CiteEntry
from ..generic.v0_2 import Doi as Doi
from ..generic.v0_2 import GenericModelDescrBase
from ..generic.v0_2 import LinkedResource as LinkedResource
from ..generic.v0_2 import Maintainer as Maintainer
from ..generic.v0_2 import OrcidId as OrcidId
from ..generic.v0_2 import RelativeFilePath as RelativeFilePath
from ..generic.v0_2 import ResourceId as ResourceId
from ..generic.v0_2 import Uploader as Uploader
from ._v0_4_converter import convert_from_older_format


class ModelId(ResourceId):
    pass


AxesStr = Annotated[
    str, RestrictCharacters("bitczyx"), AfterValidator(validate_unique_entries)
]
AxesInCZYX = Annotated[
    str, RestrictCharacters("czyx"), AfterValidator(validate_unique_entries)
]

PostprocessingName = Literal[
    "binarize",
    "clip",
    "scale_linear",
    "sigmoid",
    "zero_mean_unit_variance",
    "scale_range",
    "scale_mean_variance",
]
PreprocessingName = Literal[
    "binarize",
    "clip",
    "scale_linear",
    "sigmoid",
    "zero_mean_unit_variance",
    "scale_range",
]


class TensorName(LowerCaseIdentifier):
    pass


class CallableFromDepencencyNode(Node):
    _submodule_adapter: ClassVar[TypeAdapter[Identifier]] = TypeAdapter(Identifier)

    module_name: str
    """The Python module that implements **callable_name**."""

    @field_validator("module_name", mode="after")
    def _check_submodules(cls, module_name: str) -> str:
        for submod in module_name.split("."):
            _ = cls._submodule_adapter.validate_python(submod)

        return module_name

    callable_name: Identifier
    """The callable Python identifier implemented in module **module_name**."""


class CallableFromDepencency(ValidatedStringWithInnerNode[CallableFromDepencencyNode]):
    _inner_node_class = CallableFromDepencencyNode
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            str,
            StringConstraints(strip_whitespace=True, pattern=r"^.+\..+$"),
        ]
    ]

    @classmethod
    def _get_data(cls, valid_string_data: str):
        *mods, callname = valid_string_data.split(".")
        return dict(module_name=".".join(mods), callable_name=callname)

    @property
    def module_name(self):
        """The Python module that implements **callable_name**."""
        return self._inner_node.module_name

    @property
    def callable_name(self):
        """The callable Python identifier implemented in module **module_name**."""
        return self._inner_node.callable_name
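

# Example (illustrative sketch, not from the original module): a dotted string
# is split on "." so that the last segment becomes the callable name and the
# remaining segments are re-joined into the module path:
#
#     c = CallableFromDepencency("my_module.submodule.get_my_model")
#     assert c.module_name == "my_module.submodule"
#     assert c.callable_name == "get_my_model"
#
# (assumes the validated-string type accepts a plain str at construction)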


class CallableFromFileNode(Node):
    source_file: Annotated[
        Union[RelativeFilePath, HttpUrl],
        Field(union_mode="left_to_right"),
        include_in_package,
    ]
    """The Python source file that implements **callable_name**."""
    callable_name: Identifier
    """The callable Python identifier implemented in **source_file**."""


class CallableFromFile(ValidatedStringWithInnerNode[CallableFromFileNode]):
    _inner_node_class = CallableFromFileNode
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            str,
            StringConstraints(strip_whitespace=True, pattern=r"^.+:.+$"),
        ]
    ]

    @classmethod
    def _get_data(cls, valid_string_data: str):
        *file_parts, callname = valid_string_data.split(":")
        return dict(source_file=":".join(file_parts), callable_name=callname)

    @property
    def source_file(self):
        """The Python source file that implements **callable_name**."""
        return self._inner_node.source_file

    @property
    def callable_name(self):
        """The callable Python identifier implemented in **source_file**."""
        return self._inner_node.callable_name
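

# Example (illustrative sketch, not from the original module): the callable
# string splits on the last ":" so that the final segment is the callable name
# and everything before it is the source file (which may itself contain ":",
# e.g. in an "https://..." URL):
#
#     c = CallableFromFile("my_function.py:MyNetworkClass")
#     assert str(c.source_file).endswith("my_function.py")
#     assert c.callable_name == "MyNetworkClass"
#
# (assumes the validated-string type accepts a plain str at construction)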


CustomCallable = Annotated[
    Union[CallableFromFile, CallableFromDepencency], Field(union_mode="left_to_right")
]


class DependenciesNode(Node):
    manager: Annotated[NotEmpty[str], Field(examples=["conda", "maven", "pip"])]
    """Dependency manager"""

    file: FileSource_
    """Dependency file"""


class Dependencies(ValidatedStringWithInnerNode[DependenciesNode]):
    _inner_node_class = DependenciesNode
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            str,
            StringConstraints(strip_whitespace=True, pattern=r"^.+:.+$"),
        ]
    ]

    @classmethod
    def _get_data(cls, valid_string_data: str):
        manager, *file_parts = valid_string_data.split(":")
        return dict(manager=manager, file=":".join(file_parts))

    @property
    def manager(self):
        """Dependency manager"""
        return self._inner_node.manager

    @property
    def file(self):
        """Dependency file"""
        return self._inner_node.file
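

# Example (illustrative sketch, not from the original module): a dependencies
# string splits on the first ":" into manager and file, so
#
#     d = Dependencies("conda:environment.yaml")
#     assert d.manager == "conda"
#     assert str(d.file).endswith("environment.yaml")
#
# (assumes the validated-string type accepts a plain str at construction; the
# `file` part is validated as a file source, hence the `str(...)` comparison)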


WeightsFormat = Literal[
    "keras_hdf5",
    "onnx",
    "pytorch_state_dict",
    "tensorflow_js",
    "tensorflow_saved_model_bundle",
    "torchscript",
]


class WeightsEntryDescrBase(FileDescr):
    type: ClassVar[WeightsFormat]
    weights_format_name: ClassVar[str]  # human readable

    source: FileSource_
    """The weights file."""

    attachments: Annotated[
        Union[AttachmentsDescr, None],
        warn(None, "Weights entry depends on additional attachments.", ALERT),
    ] = None
    """Attachments that are specific to this weights entry."""

    authors: Union[List[Author], None] = None
    """Authors:
    Either the person(s) that have trained this model, resulting in the original weights file
    (if this is the initial weights entry, i.e. it does not have a `parent`),
    or the person(s) who have converted the weights to this weights format
    (if this is a child weights entry, i.e. it has a `parent` field).
    """

    dependencies: Annotated[
        Optional[Dependencies],
        warn(
            None,
            "Custom dependencies ({value}) specified. Avoid this whenever possible "
            + "to allow execution in a wider range of software environments.",
        ),
        Field(
            examples=[
                "conda:environment.yaml",
                "maven:./pom.xml",
                "pip:./requirements.txt",
            ]
        ),
    ] = None
    """Dependency manager and dependency file, specified as `<dependency manager>:<relative file path>`."""

    parent: Annotated[
        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
    ] = None
    """The source weights these weights were converted from.
    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
    the `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
    All weight entries except one (the initial set of weights resulting from training the model)
    need to have this field."""

    @model_validator(mode="after")
    def check_parent_is_not_self(self) -> Self:
        if self.type == self.parent:
            raise ValueError("Weights entry can't be its own parent.")

        return self


class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
    type: ClassVar[WeightsFormat] = "keras_hdf5"
    weights_format_name: ClassVar[str] = "Keras HDF5"
    tensorflow_version: Optional[Version] = None
    """TensorFlow version used to create these weights"""

    @field_validator("tensorflow_version", mode="after")
    @classmethod
    def _tfv(cls, value: Any):
        if value is None:
            issue_warning(
                "missing. Please specify the TensorFlow version"
                + " these weights were created with.",
                value=value,
                severity=ALERT,
                field="tensorflow_version",
            )
        return value


class OnnxWeightsDescr(WeightsEntryDescrBase):
    type: ClassVar[WeightsFormat] = "onnx"
    weights_format_name: ClassVar[str] = "ONNX"
    opset_version: Optional[Annotated[int, Ge(7)]] = None
    """ONNX opset version"""

    @field_validator("opset_version", mode="after")
    @classmethod
    def _ov(cls, value: Any):
        if value is None:
            issue_warning(
                "Missing ONNX opset version (aka ONNX opset number). "
                + "Please specify the ONNX opset version these weights were created"
                + " with.",
                value=value,
                severity=ALERT,
                field="opset_version",
            )
        return value


class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
    type: ClassVar[WeightsFormat] = "pytorch_state_dict"
    weights_format_name: ClassVar[str] = "Pytorch State Dict"
    architecture: CustomCallable = Field(
        examples=["my_function.py:MyNetworkClass", "my_module.submodule.get_my_model"]
    )
    """Callable returning a torch.nn.Module instance.
    Local implementation: `<relative path to file>:<identifier of implementation within the file>`.
    Implementation in a dependency: `<dependency-package>.<[dependency-module]>.<identifier>`."""

    architecture_sha256: Annotated[
        Optional[Sha256],
        Field(
            description=(
                "The SHA256 of the architecture source file, if the architecture is not"
                " defined in a module listed in `dependencies`\n"
            )
            + SHA256_HINT,
        ),
    ] = None
    """The SHA256 of the architecture source file,
    if the architecture is not defined in a module listed in `dependencies`"""

    @model_validator(mode="after")
    def check_architecture_sha256(self) -> Self:
        if isinstance(self.architecture, CallableFromFile):
            if self.architecture_sha256 is None:
                raise ValueError(
                    "Missing required `architecture_sha256` for `architecture` with"
                    + " source file."
                )
        elif self.architecture_sha256 is not None:
            raise ValueError(
                "Got `architecture_sha256` for architecture that does not have a source"
                + " file."
            )

        return self

    kwargs: Dict[str, Any] = Field(
        default_factory=cast(Callable[[], Dict[str, Any]], dict)
    )
    """Keyword arguments for the `architecture` callable"""

    pytorch_version: Optional[Version] = None
    """Version of the PyTorch library used.
    If `dependencies` is specified it should include pytorch and the version has to match.
    (`dependencies` overrules `pytorch_version`)"""

    @field_validator("pytorch_version", mode="after")
    @classmethod
    def _ptv(cls, value: Any):
        if value is None:
            issue_warning(
                "missing. Please specify the PyTorch version these"
                + " PyTorch state dict weights were created with.",
                value=value,
                severity=ALERT,
                field="pytorch_version",
            )
        return value


class TorchscriptWeightsDescr(WeightsEntryDescrBase):
    type: ClassVar[WeightsFormat] = "torchscript"
    weights_format_name: ClassVar[str] = "TorchScript"
    pytorch_version: Optional[Version] = None
    """Version of the PyTorch library used."""

    @field_validator("pytorch_version", mode="after")
    @classmethod
    def _ptv(cls, value: Any):
        if value is None:
            issue_warning(
                "missing. Please specify the PyTorch version these"
                + " Torchscript weights were created with.",
                value=value,
                severity=ALERT,
                field="pytorch_version",
            )
        return value


class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
    type: ClassVar[WeightsFormat] = "tensorflow_js"
    weights_format_name: ClassVar[str] = "Tensorflow.js"
    tensorflow_version: Optional[Version] = None
    """Version of the TensorFlow library used."""

    @field_validator("tensorflow_version", mode="after")
    @classmethod
    def _tfv(cls, value: Any):
        if value is None:
            issue_warning(
                "missing. Please specify the TensorFlow version"
                + " these TensorflowJs weights were created with.",
                value=value,
                severity=ALERT,
                field="tensorflow_version",
            )
        return value

    source: FileSource_
    """The multi-file weights.
    All required files/folders should be packaged in a single zip archive."""


class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
    type: ClassVar[WeightsFormat] = "tensorflow_saved_model_bundle"
    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
    tensorflow_version: Optional[Version] = None
    """Version of the TensorFlow library used."""

    @field_validator("tensorflow_version", mode="after")
    @classmethod
    def _tfv(cls, value: Any):
        if value is None:
            issue_warning(
                "missing. Please specify the TensorFlow version"
                + " these Tensorflow saved model bundle weights were created with.",
                value=value,
                severity=ALERT,
                field="tensorflow_version",
            )
        return value


class WeightsDescr(Node):
    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
    onnx: Optional[OnnxWeightsDescr] = None
    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
        None
    )
    torchscript: Optional[TorchscriptWeightsDescr] = None

    @model_validator(mode="after")
    def check_one_entry(self) -> Self:
        if all(
            entry is None
            for entry in [
                self.keras_hdf5,
                self.onnx,
                self.pytorch_state_dict,
                self.tensorflow_js,
                self.tensorflow_saved_model_bundle,
                self.torchscript,
            ]
        ):
            raise ValueError("Missing weights entry")

        return self

    def __getitem__(
        self,
        key: WeightsFormat,
    ):
        if key == "keras_hdf5":
            ret = self.keras_hdf5
        elif key == "onnx":
            ret = self.onnx
        elif key == "pytorch_state_dict":
            ret = self.pytorch_state_dict
        elif key == "tensorflow_js":
            ret = self.tensorflow_js
        elif key == "tensorflow_saved_model_bundle":
            ret = self.tensorflow_saved_model_bundle
        elif key == "torchscript":
            ret = self.torchscript
        else:
            raise KeyError(key)

        if ret is None:
            raise KeyError(key)

        return ret

    @property
    def available_formats(self):
        return {
            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
            **({} if self.onnx is None else {"onnx": self.onnx}),
            **(
                {}
                if self.pytorch_state_dict is None
                else {"pytorch_state_dict": self.pytorch_state_dict}
            ),
            **(
                {}
                if self.tensorflow_js is None
                else {"tensorflow_js": self.tensorflow_js}
            ),
            **(
                {}
                if self.tensorflow_saved_model_bundle is None
                else {
                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
                }
            ),
            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
        }

    @property
    def missing_formats(self):
        return {
            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
        }
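

# Example (illustrative sketch, not from the original module): `WeightsDescr`
# offers mapping-style access to the formats that are actually present.
# Assuming `onnx_entry` is a valid `OnnxWeightsDescr`:
#
#     wd = WeightsDescr(onnx=onnx_entry)
#     wd["onnx"]                          # returns onnx_entry
#     wd["torchscript"]                   # raises KeyError (entry is None)
#     list(wd.available_formats)          # ["onnx"]
#     "keras_hdf5" in wd.missing_formats  # True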


class ParameterizedInputShape(Node):
    """A sequence of valid shapes given by `shape_k = min + k * step for k in {0, 1, ...}`."""

    min: NotEmpty[List[int]]
    """The minimum input shape"""

    step: NotEmpty[List[int]]
    """The minimum shape change"""

    def __len__(self) -> int:
        return len(self.min)

    @model_validator(mode="after")
    def matching_lengths(self) -> Self:
        if len(self.min) != len(self.step):
            raise ValueError("`min` and `step` required to have the same length")

        return self
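

# Worked example (illustrative, derived from the formula in the docstring above):
# with `min=[1, 64, 64, 1]` and `step=[0, 32, 32, 0]` the valid input shapes are
#     k=0: [1, 64, 64, 1]
#     k=1: [1, 96, 96, 1]
#     k=2: [1, 128, 128, 1]
#     ...
# i.e. the batch and channel axes stay fixed while both spatial axes grow in
# increments of 32.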


class ImplicitOutputShape(Node):
    """Output tensor shape depending on an input tensor shape.
    `shape(output_tensor) = shape(input_tensor) * scale + 2 * offset`"""

    reference_tensor: TensorName
    """Name of the reference tensor."""

    scale: NotEmpty[List[Optional[float]]]
    """output_pix/input_pix for each dimension.
    'null' values indicate new dimensions, whose length is defined by 2*`offset`"""

    offset: NotEmpty[List[Union[int, Annotated[float, MultipleOf(0.5)]]]]
    """Position of origin wrt to input."""

    def __len__(self) -> int:
        return len(self.scale)

    @model_validator(mode="after")
    def matching_lengths(self) -> Self:
        if len(self.scale) != len(self.offset):
            raise ValueError(
                f"scale {self.scale} has to have same length as offset {self.offset}!"
            )
        # if we have an expanded dimension, make sure that its offset is not zero
        for sc, off in zip(self.scale, self.offset):
            if sc is None and not off:
                raise ValueError("`offset` must not be zero if `scale` is None")

        return self
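

# Worked example (illustrative, derived from the formula in the docstring above):
# for a reference input of shape [1, 256, 256, 1] with
#     scale  = [1.0, 1.0, 1.0, 1.0]
#     offset = [0, 32, 32, 0]
# the output shape is [1, 256 + 2*32, 256 + 2*32, 1] = [1, 320, 320, 1].
# A `scale` entry of null/None inserts a new dimension of length 2*offset.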


class TensorDescrBase(Node):
    name: TensorName
    """Tensor name. No duplicates are allowed."""

    description: str = ""

    axes: AxesStr
    """Axes identifying characters. Same length and order as the axes in `shape`.
    | axis | description |
    | --- | --- |
    | b | batch (groups multiple samples) |
    | i | instance/index/element |
    | t | time |
    | c | channel |
    | z | spatial dimension z |
    | y | spatial dimension y |
    | x | spatial dimension x |
    """

    data_range: Optional[
        Tuple[Annotated[float, AllowInfNan(True)], Annotated[float, AllowInfNan(True)]]
    ] = None
    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
    If not specified, the full data range that can be expressed in `data_type` is allowed."""


class BinarizeKwargs(KwargsNode):
    """Keyword arguments for `BinarizeDescr`"""

    threshold: float
    """The fixed threshold"""


class BinarizeDescr(NodeWithExplicitlySetFields):
    """Binarize the tensor with a fixed `BinarizeKwargs.threshold`.
    Values above the threshold will be set to one, values below the threshold to zero.
    """

    implemented_name: ClassVar[Literal["binarize"]] = "binarize"
    if TYPE_CHECKING:
        name: Literal["binarize"] = "binarize"
    else:
        name: Literal["binarize"]

    kwargs: BinarizeKwargs


class ClipKwargs(KwargsNode):
    """Keyword arguments for `ClipDescr`"""

    min: float
    """minimum value for clipping"""
    max: float
    """maximum value for clipping"""


class ClipDescr(NodeWithExplicitlySetFields):
    """Clip tensor values to a range.

    Set tensor values below `ClipKwargs.min` to `ClipKwargs.min`
    and above `ClipKwargs.max` to `ClipKwargs.max`.
    """

    implemented_name: ClassVar[Literal["clip"]] = "clip"
    if TYPE_CHECKING:
        name: Literal["clip"] = "clip"
    else:
        name: Literal["clip"]

    kwargs: ClipKwargs


class ScaleLinearKwargs(KwargsNode):
    """Keyword arguments for `ScaleLinearDescr`"""

    axes: Annotated[Optional[AxesInCZYX], Field(examples=["xy"])] = None
    """The subset of axes to scale jointly.
    For example xy to scale the two image axes for 2d data jointly."""

    gain: Union[float, List[float]] = 1.0
    """multiplicative factor"""

    offset: Union[float, List[float]] = 0.0
    """additive term"""

    @model_validator(mode="after")
    def either_gain_or_offset(self) -> Self:
        if (
            self.gain == 1.0
            or isinstance(self.gain, list)
            and all(g == 1.0 for g in self.gain)
        ) and (
            self.offset == 0.0
            or isinstance(self.offset, list)
            and all(off == 0.0 for off in self.offset)
        ):
            raise ValueError(
                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or"
                + " `offset` != 0.0."
            )

        return self


class ScaleLinearDescr(NodeWithExplicitlySetFields):
    """Fixed linear scaling."""

    implemented_name: ClassVar[Literal["scale_linear"]] = "scale_linear"
    if TYPE_CHECKING:
        name: Literal["scale_linear"] = "scale_linear"
    else:
        name: Literal["scale_linear"]

    kwargs: ScaleLinearKwargs
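

# Example (illustrative sketch, not from the original module): fixed linear
# scaling computes `out = gain * tensor + offset` (per the field docs above).
# A per-channel variant for a 3-channel image scaled jointly over `xy`:
#
#     ScaleLinearKwargs(axes="xy", gain=[1.0, 2.0, 0.5], offset=[0.0, -1.0, 0.0])
#
# The identity (gain 1.0 and offset 0.0) is rejected by `either_gain_or_offset`.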


class SigmoidDescr(NodeWithExplicitlySetFields):
    """The logistic sigmoid function, a.k.a. expit function."""

    implemented_name: ClassVar[Literal["sigmoid"]] = "sigmoid"
    if TYPE_CHECKING:
        name: Literal["sigmoid"] = "sigmoid"
    else:
        name: Literal["sigmoid"]

    @property
    def kwargs(self) -> KwargsNode:
        """empty kwargs"""
        return KwargsNode()


class ZeroMeanUnitVarianceKwargs(KwargsNode):
    """Keyword arguments for `ZeroMeanUnitVarianceDescr`"""

    mode: Literal["fixed", "per_dataset", "per_sample"] = "fixed"
    """Mode for computing mean and variance.
    | mode | description |
    | ----------- | ------------------------------------ |
    | fixed | Fixed values for mean and variance |
    | per_dataset | Compute for the entire dataset |
    | per_sample | Compute for each sample individually |
    """
    axes: Annotated[AxesInCZYX, Field(examples=["xy"])]
    """The subset of axes to normalize jointly.
    For example `xy` to normalize the two image axes for 2d data jointly."""

    mean: Annotated[
        Union[float, NotEmpty[List[float]], None], Field(examples=[(1.1, 2.2, 3.3)])
    ] = None
    """The mean value(s) to use for `mode: fixed`.
    For example `[1.1, 2.2, 3.3]` in the case of a 3 channel image with `axes: xy`."""
    # todo: check if means match input axes (for mode 'fixed')

    std: Annotated[
        Union[float, NotEmpty[List[float]], None], Field(examples=[(0.1, 0.2, 0.3)])
    ] = None
    """The standard deviation values to use for `mode: fixed`. Analogous to mean."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""

    @model_validator(mode="after")
    def mean_and_std_match_mode(self) -> Self:
        if self.mode == "fixed" and (self.mean is None or self.std is None):
            raise ValueError("`mean` and `std` are required for `mode: fixed`.")
        elif self.mode != "fixed" and (self.mean is not None or self.std is not None):
            raise ValueError(f"`mean` and `std` not allowed for `mode: {self.mode}`")

        return self


class ZeroMeanUnitVarianceDescr(NodeWithExplicitlySetFields):
    """Subtract mean and divide by variance."""

    implemented_name: ClassVar[Literal["zero_mean_unit_variance"]] = (
        "zero_mean_unit_variance"
    )
    if TYPE_CHECKING:
        name: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
    else:
        name: Literal["zero_mean_unit_variance"]

    kwargs: ZeroMeanUnitVarianceKwargs
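

# Example (illustrative sketch, not from the original module): with
# `mode: fixed` both statistics must be given, matching the validator above:
#
#     ZeroMeanUnitVarianceKwargs(
#         mode="fixed", axes="xy", mean=[1.1, 2.2, 3.3], std=[0.1, 0.2, 0.3]
#     )
#
# Each output value is then `out = (tensor - mean) / (std + eps)` per channel.
# For `per_sample`/`per_dataset`, `mean` and `std` must be omitted and are
# computed by the consumer software.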


class ScaleRangeKwargs(KwargsNode):
    """Keyword arguments for `ScaleRangeDescr`

    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
    this processing step normalizes data to the [0, 1] interval.
    For other percentiles the normalized values will partially be outside the [0, 1]
    interval. Use `ScaleRange` followed by `ClipDescr` if you want to limit the
    normalized values to a range.
    """

    mode: Literal["per_dataset", "per_sample"]
    """Mode for computing percentiles.
    | mode | description |
    | ----------- | ------------------------------------ |
    | per_dataset | compute for the entire dataset |
    | per_sample | compute for each sample individually |
    """
    axes: Annotated[AxesInCZYX, Field(examples=["xy"])]
    """The subset of axes to normalize jointly.
    For example xy to normalize the two image axes for 2d data jointly."""

    min_percentile: Annotated[Union[int, float], Interval(ge=0, lt=100)] = 0.0
    """The lower percentile used to determine the value to align with zero."""

    max_percentile: Annotated[Union[int, float], Interval(gt=1, le=100)] = 100.0
    """The upper percentile used to determine the value to align with one.
    Has to be bigger than `min_percentile`.
    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
    accepting percentiles specified in the range 0.0 to 1.0."""

    @model_validator(mode="after")
    def min_smaller_max(self, info: ValidationInfo) -> Self:
        if self.min_percentile >= self.max_percentile:
            raise ValueError(
                f"min_percentile {self.min_percentile} >= max_percentile"
                + f" {self.max_percentile}"
            )

        return self

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability.
    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
    with `v_lower,v_upper` values at the respective percentiles."""

    reference_tensor: Optional[TensorName] = None
    """Tensor name to compute the percentiles from. Default: The tensor itself.
    For any tensor in `inputs` only input tensor references are allowed.
    For a tensor in `outputs` only input tensor references are allowed if `mode: per_dataset`"""


class ScaleRangeDescr(NodeWithExplicitlySetFields):
    """Scale with percentiles."""

    implemented_name: ClassVar[Literal["scale_range"]] = "scale_range"
    if TYPE_CHECKING:
        name: Literal["scale_range"] = "scale_range"
    else:
        name: Literal["scale_range"]

    kwargs: ScaleRangeKwargs
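

# Worked example (illustrative, derived from the `eps` docstring above): with
# `min_percentile=1.0` and `max_percentile=99.8`, `v_lower` and `v_upper` are
# the 1.0th and 99.8th percentiles of the tensor, and
#     out = (tensor - v_lower) / (v_upper - v_lower + eps)
# maps values at the 1.0th percentile to ~0 and at the 99.8th percentile to ~1;
# outliers beyond those percentiles fall outside [0, 1] unless clipped.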


class ScaleMeanVarianceKwargs(KwargsNode):
    """Keyword arguments for `ScaleMeanVarianceDescr`"""

    mode: Literal["per_dataset", "per_sample"]
    """Mode for computing mean and variance.
    | mode | description |
    | ----------- | ------------------------------------ |
    | per_dataset | Compute for the entire dataset |
    | per_sample | Compute for each sample individually |
    """

    reference_tensor: TensorName
    """Name of tensor to match."""

    axes: Annotated[Optional[AxesInCZYX], Field(examples=["xy"])] = None
    """The subset of axes to scale jointly.
    For example xy to normalize the two image axes for 2d data jointly.
    Default: scale all non-batch axes jointly."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability:
    `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean`."""


class ScaleMeanVarianceDescr(NodeWithExplicitlySetFields):
    """Scale the tensor s.t. its mean and variance match a reference tensor."""

    implemented_name: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
    if TYPE_CHECKING:
        name: Literal["scale_mean_variance"] = "scale_mean_variance"
    else:
        name: Literal["scale_mean_variance"]

    kwargs: ScaleMeanVarianceKwargs


PreprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        ScaleLinearDescr,
        SigmoidDescr,
        ZeroMeanUnitVarianceDescr,
        ScaleRangeDescr,
    ],
    Discriminator("name"),
]
PostprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        ScaleLinearDescr,
        SigmoidDescr,
        ZeroMeanUnitVarianceDescr,
        ScaleRangeDescr,
        ScaleMeanVarianceDescr,
    ],
    Discriminator("name"),
]


class InputTensorDescr(TensorDescrBase):
    data_type: Literal["float32", "uint8", "uint16"]
    """For now an input tensor is expected to be given as `float32`.
    The data flow in bioimage.io models is explained
    [in this diagram.](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit)."""

    shape: Annotated[
        Union[Sequence[int], ParameterizedInputShape],
        Field(
            examples=[(1, 512, 512, 1), dict(min=(1, 64, 64, 1), step=(0, 32, 32, 0))]
        ),
    ]
    """Specification of input tensor shape."""

    preprocessing: List[PreprocessingDescr] = Field(
        default_factory=cast(  # TODO: (py>3.8) use list[PreprocessingDescr]
            Callable[[], List[PreprocessingDescr]], list
        )
    )
    """Description of how this input should be preprocessed."""

    @model_validator(mode="after")
    def zero_batch_step_and_one_batch_size(self) -> Self:
        bidx = self.axes.find("b")
        if bidx == -1:
            return self

        if isinstance(self.shape, ParameterizedInputShape):
            step = self.shape.step
            shape = self.shape.min
            if step[bidx] != 0:
                raise ValueError(
                    "Input shape step has to be zero in the batch dimension (the batch"
                    + " dimension can always be increased, but `step` should specify how"
                    + " to increase the minimal shape to find the largest single batch"
                    + " shape)"
                )
        else:
            shape = self.shape

        if shape[bidx] != 1:
            raise ValueError("Input shape has to be 1 in the batch dimension b.")

        return self

    @model_validator(mode="after")
    def validate_preprocessing_kwargs(self) -> Self:
        for p in self.preprocessing:
            kwargs_axes = p.kwargs.get("axes")
            if isinstance(kwargs_axes, str) and any(
                a not in self.axes for a in kwargs_axes
            ):
                raise ValueError("`kwargs.axes` needs to be subset of `axes`")

        return self


class OutputTensorDescr(TensorDescrBase):
    data_type: Literal[
        "float32",
        "float64",
        "uint8",
        "int8",
        "uint16",
        "int16",
        "uint32",
        "int32",
        "uint64",
        "int64",
        "bool",
    ]
    """Data type.
    The data flow in bioimage.io models is explained
    [in this diagram.](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit)."""

    shape: Union[Sequence[int], ImplicitOutputShape]
    """Output tensor shape."""

    halo: Optional[Sequence[int]] = None
    """The `halo` that should be cropped from the output tensor to avoid boundary effects.
    The `halo` is to be cropped from both sides, i.e. `shape_after_crop = shape - 2 * halo`.
    To document a `halo` that is already cropped by the model `shape.offset` has to be used instead."""

    postprocessing: List[PostprocessingDescr] = Field(
        default_factory=cast(Callable[[], List[PostprocessingDescr]], list)
    )
    """Description of how this output should be postprocessed."""

    @model_validator(mode="after")
    def matching_halo_length(self) -> Self:
        if self.halo and len(self.halo) != len(self.shape):
            raise ValueError(
                f"halo {self.halo} has to have same length as shape {self.shape}!"
            )

        return self

    @model_validator(mode="after")
    def validate_postprocessing_kwargs(self) -> Self:
        for p in self.postprocessing:
            kwargs_axes = p.kwargs.get("axes", "")
            if not isinstance(kwargs_axes, str):
                raise ValueError(f"Expected {kwargs_axes} to be a string")

            if any(a not in self.axes for a in kwargs_axes):
                raise ValueError("`kwargs.axes` needs to be subset of axes")

        return self
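

# Worked example (illustrative, derived from the `halo` docstring above): an
# output of shape [1, 320, 320, 1] with `halo=[0, 16, 16, 0]` yields
#     shape_after_crop = [1, 320 - 2*16, 320 - 2*16, 1] = [1, 288, 288, 1]
# after the consumer crops 16 pixels from each side of both spatial axes.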


KnownRunMode = Literal["deepimagej"]


class RunMode(Node):
    name: Annotated[
        Union[KnownRunMode, str], warn(KnownRunMode, "Unknown run mode '{value}'.")
    ]
    """Run mode name"""

    kwargs: Dict[str, Any] = Field(
        default_factory=cast(Callable[[], Dict[str, Any]], dict)
    )
    """Run mode specific keyword arguments"""


class LinkedModel(Node):
    """Reference to a bioimage.io model."""

    id: Annotated[ModelId, Field(examples=["affable-shark", "ambitious-sloth"])]
    """A valid model `id` from the bioimage.io collection."""

    version_number: Optional[int] = None
    """version number (n-th published version, not the semantic version) of linked model"""


def package_weights(
    value: Node,  # Union[v0_4.WeightsDescr, v0_5.WeightsDescr]
    handler: SerializerFunctionWrapHandler,
    info: SerializationInfo,
):
    ctxt = packaging_context_var.get()
    if ctxt is not None and ctxt.weights_priority_order is not None:
        for wf in ctxt.weights_priority_order:
            w = getattr(value, wf, None)
            if w is not None:
                break
        else:
            raise ValueError(
                "None of the weight formats in `weights_priority_order`"
                + f" ({ctxt.weights_priority_order}) is present in the given model."
            )

        assert isinstance(w, Node), type(w)
        # construct WeightsDescr with new single weight format entry
        new_w = w.model_construct(**{k: v for k, v in w if k != "parent"})
        value = value.model_construct(None, **{wf: new_w})

    return handler(
        value,
        info,  # pyright: ignore[reportArgumentType] # taken from pydantic docs
    )
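

# Example (illustrative sketch, not from the original module): when a packaging
# context is active with, say, `weights_priority_order=("onnx", "torchscript")`,
# serializing a model keeps only the first available format from that order and
# drops its `parent` link, so the packaged RDF contains a single, self-contained
# weights entry. Outside a packaging context the weights are serialized as-is.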


class ModelDescr(GenericModelDescrBase):
    """Specification of the fields used in a bioimage.io-compliant RDF that describes AI models with pretrained weights.

    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
    """

    implemented_format_version: ClassVar[Literal["0.4.10"]] = "0.4.10"
    if TYPE_CHECKING:
        format_version: Literal["0.4.10"] = "0.4.10"
    else:
        format_version: Literal["0.4.10"]
    """Version of the bioimage.io model description specification used.
    When creating a new model always use the latest micro/patch version described here.
    The `format_version` is important for any consumer software to understand how to parse the fields.
    """

    implemented_type: ClassVar[Literal["model"]] = "model"
    if TYPE_CHECKING:
        type: Literal["model"] = "model"
    else:
        type: Literal["model"]
    """Specialized resource type 'model'"""

    id: Optional[ModelId] = None
    """bioimage.io-wide unique resource identifier
    assigned by bioimage.io; version **un**specific."""

    authors: NotEmpty[  # pyright: ignore[reportGeneralTypeIssues] # make mandatory
        List[Author]
    ]
    """The authors are the creators of the model RDF and the primary points of contact."""

    documentation: Annotated[
        FileSource_,
        Field(
            examples=[
                "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md",
                "README.md",
            ],
        ),
    ]
    """URL or relative path to a markdown file with additional documentation.
    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
    The documentation should include a '[#[#]]# Validation' (sub)section
    with details on how to quantitatively validate the model on unseen data."""

    inputs: NotEmpty[List[InputTensorDescr]]
    """Describes the input tensors expected by this model."""

    license: Annotated[
        Union[LicenseId, str],
        warn(LicenseId, "Unknown license id '{value}'."),
        Field(examples=["CC0-1.0", "MIT", "BSD-2-Clause"]),
    ]
    """A [SPDX license identifier](https://spdx.org/licenses/).
    We do not support custom licenses beyond the SPDX license list. If you need that, please
    [open a GitHub issue](https://github.com/bioimage-io/spec-bioimage-io/issues/new/choose
    ) to discuss your intentions with the community."""

    name: Annotated[
        str,
        MinLen(1),
        warn(MinLen(5), "Name shorter than 5 characters.", INFO),
        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
    ]
    """A human-readable name of this model.
    It should be no longer than 64 characters and only contain letters, numbers,
    underscores, minus signs, or spaces."""

    outputs: NotEmpty[List[OutputTensorDescr]]
    """Describes the output tensors."""

    @field_validator("inputs", "outputs")
    @classmethod
    def unique_tensor_descr_names(
        cls, value: Sequence[Union[InputTensorDescr, OutputTensorDescr]]
    ) -> Sequence[Union[InputTensorDescr, OutputTensorDescr]]:
        unique_names = {str(v.name) for v in value}
        if len(unique_names) != len(value):
            raise ValueError("Duplicate tensor descriptor names")

        return value

    @model_validator(mode="after")
    def unique_io_names(self) -> Self:
        unique_names = {str(ss.name) for s in (self.inputs, self.outputs) for ss in s}
        if len(unique_names) != (len(self.inputs) + len(self.outputs)):
            raise ValueError("Duplicate tensor descriptor names across inputs/outputs")

        return self

    @model_validator(mode="after")
    def minimum_shape2valid_output(self) -> Self:
        tensors_by_name: Dict[
            TensorName, Union[InputTensorDescr, OutputTensorDescr]
        ] = {t.name: t for t in self.inputs + self.outputs}

        for out in self.outputs:
            if isinstance(out.shape, ImplicitOutputShape):
                ndim_ref = len(tensors_by_name[out.shape.reference_tensor].shape)
                ndim_out_ref = len(
                    [scale for scale in out.shape.scale if scale is not None]
                )
                if ndim_ref != ndim_out_ref:
                    expanded_dim_note = (
                        " Note that expanded dimensions (`scale`: null) are not"
                        + f" counted for {out.name}'s dimensionality here."
                        if None in out.shape.scale
                        else ""
                    )
                    raise ValueError(
                        f"Referenced tensor '{out.shape.reference_tensor}' with"
                        + f" {ndim_ref} dimensions does not match output tensor"
                        + f" '{out.name}' with"
                        + f" {ndim_out_ref} dimensions.{expanded_dim_note}"
                    )

            min_out_shape = self._get_min_shape(out, tensors_by_name)
            if out.halo:
                halo = out.halo
                halo_msg = f" for halo {out.halo}"
            else:
                halo = [0] * len(min_out_shape)
                halo_msg = ""

            if any([s - 2 * h < 1 for s, h in zip(min_out_shape, halo)]):
                raise ValueError(
                    f"Minimal shape {min_out_shape} of output {out.name} is too"
                    + f" small{halo_msg}."
                )

        return self

    @classmethod
    def _get_min_shape(
        cls,
        t: Union[InputTensorDescr, OutputTensorDescr],
        tensors_by_name: Dict[TensorName, Union[InputTensorDescr, OutputTensorDescr]],
    ) -> Sequence[int]:
        """output with subtracted halo has to result in meaningful output even for the minimal input
        see https://github.com/bioimage-io/spec-bioimage-io/issues/392
        """
        if isinstance(t.shape, collections.abc.Sequence):
            return t.shape
        elif isinstance(t.shape, ParameterizedInputShape):
            return t.shape.min
        elif isinstance(t.shape, ImplicitOutputShape):
            pass
        else:
            assert_never(t.shape)

        ref_shape = cls._get_min_shape(
            tensors_by_name[t.shape.reference_tensor], tensors_by_name
        )

        if None not in t.shape.scale:
            scale: Sequence[float] = t.shape.scale  # type: ignore
        else:
            expanded_dims = [idx for idx, sc in enumerate(t.shape.scale) if sc is None]
            new_ref_shape: List[int] = []
            for idx in range(len(t.shape.scale)):
                ref_idx = idx - sum(int(exp < idx) for exp in expanded_dims)
                new_ref_shape.append(1 if idx in expanded_dims else ref_shape[ref_idx])

            ref_shape = new_ref_shape
            assert len(ref_shape) == len(t.shape.scale)
            scale = [0.0 if sc is None else sc for sc in t.shape.scale]

        offset = t.shape.offset
        assert len(offset) == len(scale)
        return [int(rs * s + 2 * off) for rs, s, off in zip(ref_shape, scale, offset)]
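
    # Worked example (illustrative, following the recursion above): given an
    # input "raw" with parameterized shape min=[1, 64, 64, 1] and an output
    # whose shape is implicit with reference_tensor="raw",
    # scale=[1.0, 1.0, 1.0, 1.0], offset=[0, 0, 0, 0], the minimal output shape
    # is [1, 64, 64, 1]. With halo=[0, 8, 8, 0], `minimum_shape2valid_output`
    # checks 64 - 2*8 = 48 >= 1 for each spatial axis, so the model validates.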

    @model_validator(mode="after")
    def validate_tensor_references_in_inputs(self) -> Self:
        for t in self.inputs:
            for proc in t.preprocessing:
                if "reference_tensor" not in proc.kwargs:
                    continue

                ref_tensor = proc.kwargs["reference_tensor"]
                if ref_tensor is not None and str(ref_tensor) not in {
                    str(t.name) for t in self.inputs
                }:
                    raise ValueError(f"'{ref_tensor}' not found in inputs")

                if ref_tensor == t.name:
                    raise ValueError(
                        f"invalid self reference for preprocessing of tensor {t.name}"
                    )

        return self

    @model_validator(mode="after")
    def validate_tensor_references_in_outputs(self) -> Self:
        for t in self.outputs:
            for proc in t.postprocessing:
                if "reference_tensor" not in proc.kwargs:
                    continue

                ref_tensor = proc.kwargs["reference_tensor"]
                if ref_tensor is not None and str(ref_tensor) not in {
                    str(t.name) for t in self.inputs
                }:
                    raise ValueError(f"{ref_tensor} not found in inputs")

        return self

    packaged_by: List[Author] = Field(
        default_factory=cast(Callable[[], List[Author]], list)
    )
    """The persons that have packaged and uploaded this model.
    Only required if those persons differ from the `authors`."""

    parent: Optional[LinkedModel] = None
    """The model from which this model is derived, e.g. by fine-tuning the weights."""

    @field_validator("parent", mode="before")
    @classmethod
    def ignore_url_parent(cls, parent: Any):
        if isinstance(parent, dict):
            return None
        else:
            return parent

    run_mode: Optional[RunMode] = None
    """Custom run mode for this model: for more complex prediction procedures like test time
    data augmentation that currently cannot be expressed in the specification.
    No standard run modes are defined yet."""

    sample_inputs: List[FileSource_] = Field(
        default_factory=cast(Callable[[], List[FileSource_]], list)
    )
    """URLs/relative paths to sample inputs to illustrate possible inputs for the model,
    for example stored as PNG or TIFF images.
    The sample files primarily serve to inform a human user about an example use case."""

    sample_outputs: List[FileSource_] = Field(
        default_factory=cast(Callable[[], List[FileSource_]], list)
    )
    """URLs/relative paths to sample outputs corresponding to the `sample_inputs`."""

    test_inputs: NotEmpty[
        List[Annotated[FileSource_, WithSuffix(".npy", case_sensitive=True)]]
    ]
    """Test input tensors compatible with the `inputs` description for a **single test case**.
    This means if your model has more than one input, you should provide one URL/relative path for each input.
    Each test input should be a file with an ndarray in
    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
    The extension must be '.npy'."""

    test_outputs: NotEmpty[
        List[Annotated[FileSource_, WithSuffix(".npy", case_sensitive=True)]]
    ]
    """Analogous to `test_inputs`."""

    timestamp: Datetime
    """Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format
    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat)."""

    training_data: Union[LinkedDataset, DatasetDescr, None] = None
    """The dataset used to train this model"""

    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
    """The weights for this model.
    Weights can be given for different formats, but should otherwise be equivalent.
    The available weight formats determine which consumers can use this model."""

    @model_validator(mode="before")
    @classmethod
    def _convert_from_older_format(
        cls, data: BioimageioYamlContent, /
    ) -> BioimageioYamlContent:
        convert_from_older_format(data)
        return data

    def get_input_test_arrays(self) -> List[NDArray[Any]]:
        data = [load_array(ipt) for ipt in self.test_inputs]
        assert all(isinstance(d, np.ndarray) for d in data)
        return data

    def get_output_test_arrays(self) -> List[NDArray[Any]]:
        data = [load_array(out) for out in self.test_outputs]
        assert all(isinstance(d, np.ndarray) for d in data)
        return data