bioimageio.core.stat_measures
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import (
    Any,
    Dict,
    Literal,
    Mapping,
    Optional,
    Protocol,
    Tuple,
    TypeVar,
    Union,
)

import numpy as np
from pydantic import (
    BaseModel,
    BeforeValidator,
    Discriminator,
    PlainSerializer,
)
from typing_extensions import Annotated

from .axis import AxisId
from .common import MemberId, PerMember
from .tensor import Tensor


def tensor_custom_before_validator(data: Union[Tensor, Mapping[str, Any]]):
    """Coerce `data` into a `Tensor` before pydantic validation.

    Accepts an existing `Tensor` (returned unchanged) or a mapping with
    ``"data"`` (array-like) and ``"dims"`` entries, i.e. the format produced by
    `tensor_custom_serializer`.

    Raises:
        ValueError: if `data` is neither a `Tensor` nor subscriptable with both
            ``"data"`` and ``"dims"``. A `ValueError` (instead of a raw
            `KeyError`/`TypeError`) lets pydantic report an ordinary validation
            error, e.g. when the Tensor branch of the `MeasureValue` union is
            tried for an input that is not a tensor at all.
    """
    if isinstance(data, Tensor):
        return data

    try:
        raw_data = data["data"]
        dims = data["dims"]
    except (KeyError, TypeError, IndexError) as e:
        raise ValueError(
            "Expected a Tensor or a mapping with 'data' and 'dims' entries,"
            + f" got {type(data).__name__}"
        ) from e

    return Tensor(np.asarray(raw_data), dims=dims)


def tensor_custom_serializer(t: Tensor) -> Dict[str, Any]:
    """Serialize a `Tensor` into a JSON-compatible ``{"data", "dims"}`` mapping.

    The result can be fed back into `tensor_custom_before_validator`.
    """
    # `t.data.data` is the raw array underneath the tensor's data container
    # (presumably an xarray.DataArray wrapping a numpy array -- TODO confirm).
    return {"data": t.data.data.tolist(), "dims": list(map(str, t.dims))}


# A measure evaluates either to a scalar or to a (partially reduced) tensor.
# Tensors are (de)serialized via the custom validator/serializer above.
MeasureValue = Union[
    float,
    Annotated[
        Tensor,
        BeforeValidator(tensor_custom_before_validator),
        PlainSerializer(tensor_custom_serializer),
    ],
]


# using Sample Protocol really only to avoid circular imports
class SampleLike(Protocol):
    """Structural type for anything exposing its tensors as `members`."""

    @property
    def members(self) -> PerMember[Tensor]: ...


class MeasureBase(BaseModel, frozen=True):
    """Common base of all measures: names the member tensor to measure.

    Measures are frozen (immutable), so instances are hashable.
    """

    member_id: MemberId


class SampleMeasureBase(MeasureBase, ABC, frozen=True):
    """Base of measures that are computed from a single sample."""

    scope: Literal["sample"] = "sample"

    @abstractmethod
    def compute(self, sample: SampleLike) -> MeasureValue:
        """compute the measure"""
        ...
class DatasetMeasureBase(MeasureBase, ABC, frozen=True):
    """Base of measures that are accumulated across multiple samples."""

    scope: Literal["dataset"] = "dataset"


def _check_batch_axis(axes: Optional[Tuple[AxisId, ...]], *, dataset: bool) -> None:
    """Validate how a measure's reduction `axes` treat the ``batch`` axis.

    Sample measures must not reduce over ``batch``; dataset measures must
    (whenever explicit `axes` are given). ``axes=None`` is always accepted.

    Raises:
        ValueError: if `axes` violates the rule for the given scope. An explicit
            raise is used instead of `assert`, which is stripped under ``-O``.
    """
    if axes is None:
        return

    reduces_batch = AxisId("batch") in axes
    if dataset and not reduces_batch:
        raise ValueError(
            f"dataset measures must reduce over the 'batch' axis, got axes={axes}"
        )
    if not dataset and reduces_batch:
        raise ValueError(
            f"sample measures must not reduce over the 'batch' axis, got axes={axes}"
        )


class _Mean(BaseModel, frozen=True):
    name: Literal["mean"] = "mean"
    axes: Optional[Tuple[AxisId, ...]] = None
    """`axes` to reduce"""


class SampleMean(_Mean, SampleMeasureBase, frozen=True):
    """The mean value of a single tensor"""

    def compute(self, sample: SampleLike) -> MeasureValue:
        tensor = sample.members[self.member_id]
        return tensor.mean(dim=self.axes)

    def model_post_init(self, __context: Any):
        _check_batch_axis(self.axes, dataset=False)


class DatasetMean(_Mean, DatasetMeasureBase, frozen=True):
    """The mean value across multiple samples"""

    def model_post_init(self, __context: Any):
        _check_batch_axis(self.axes, dataset=True)


class _Std(BaseModel, frozen=True):
    name: Literal["std"] = "std"
    axes: Optional[Tuple[AxisId, ...]] = None
    """`axes` to reduce"""


class SampleStd(_Std, SampleMeasureBase, frozen=True):
    """The standard deviation of a single tensor"""

    def compute(self, sample: SampleLike) -> MeasureValue:
        tensor = sample.members[self.member_id]
        return tensor.std(dim=self.axes)

    def model_post_init(self, __context: Any):
        _check_batch_axis(self.axes, dataset=False)


class DatasetStd(_Std, DatasetMeasureBase, frozen=True):
    """The standard deviation across multiple samples"""

    def model_post_init(self, __context: Any):
        _check_batch_axis(self.axes, dataset=True)


class _Var(BaseModel, frozen=True):
    name: Literal["var"] = "var"
    axes: Optional[Tuple[AxisId, ...]] = None
    """`axes` to reduce"""


class SampleVar(_Var, SampleMeasureBase, frozen=True):
    """The variance of a single tensor"""

    def compute(self, sample: SampleLike) -> MeasureValue:
        tensor = sample.members[self.member_id]
        return tensor.var(dim=self.axes)

    def model_post_init(self, __context: Any):
        _check_batch_axis(self.axes, dataset=False)


class DatasetVar(_Var, DatasetMeasureBase, frozen=True):
    """The variance across multiple samples"""

    def model_post_init(self, __context: Any):
        _check_batch_axis(self.axes, dataset=True)


class _Quantile(BaseModel, frozen=True):
    name: Literal["quantile"] = "quantile"
    q: float
    axes: Optional[Tuple[AxisId, ...]] = None
    """`axes` to reduce"""

    def model_post_init(self, __context: Any):
        # `q` is a quantile (a fraction), not a percentile; NaN is rejected too
        if not 0.0 <= self.q <= 1.0:
            raise ValueError(f"`q` must be in [0, 1], got {self.q}")


class SampleQuantile(_Quantile, SampleMeasureBase, frozen=True):
    """The `q`-quantile of a single tensor (`q` is a fraction in [0, 1])"""

    def compute(self, sample: SampleLike) -> MeasureValue:
        tensor = sample.members[self.member_id]
        return tensor.quantile(self.q, dim=self.axes)

    def model_post_init(self, __context: Any):
        super().model_post_init(__context)  # validates `q`
        _check_batch_axis(self.axes, dataset=False)


class DatasetPercentile(_Quantile, DatasetMeasureBase, frozen=True):
    """The `q`-quantile across multiple samples

    Note: despite the class name, `q` is a quantile in [0, 1], not a percentile.
    """

    def model_post_init(self, __context: Any):
        super().model_post_init(__context)  # validates `q`
        _check_batch_axis(self.axes, dataset=True)


# Tagged unions: `name` selects the statistic, `scope` selects sample/dataset.
SampleMeasure = Annotated[
    Union[SampleMean, SampleStd, SampleVar, SampleQuantile], Discriminator("name")
]
DatasetMeasure = Annotated[
    Union[DatasetMean, DatasetStd, DatasetVar, DatasetPercentile], Discriminator("name")
]
Measure = Annotated[Union[SampleMeasure, DatasetMeasure], Discriminator("scope")]
# measures are frozen (hashable) models, so they can serve as dict keys
Stat = Dict[Measure, MeasureValue]

MeanMeasure = Union[SampleMean, DatasetMean]
StdMeasure = Union[SampleStd, DatasetStd]
VarMeasure = Union[SampleVar, DatasetVar]
PercentileMeasure = Union[SampleQuantile, DatasetPercentile]
MeanMeasureT = TypeVar("MeanMeasureT", bound=MeanMeasure)
StdMeasureT = TypeVar("StdMeasureT", bound=StdMeasure)
VarMeasureT = TypeVar("VarMeasureT", bound=VarMeasure)
PercentileMeasureT = TypeVar("PercentileMeasureT", bound=PercentileMeasure)
Base class for protocol classes.
Protocol classes are defined as::
    class Proto(Protocol):
        def meth(self) -> int:
            ...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).
For example::
    class C:
        def meth(self) -> int:
            return 0

    def func(x: Proto) -> int:
        return x.meth()

    func(C())  # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::
    class GenProto[T](Protocol):
        def meth(self) -> T:
            ...
def _no_init_or_replace_init(self, *args, **kwargs):
    """Placeholder ``__init__`` used by protocol classes and their subclasses.

    Raises ``TypeError`` when a protocol class itself is instantiated.
    Otherwise, on the first instantiation of a subclass, replaces
    ``cls.__init__`` with the first real ``__init__`` found along the MRO
    (falling back to ``object.__init__``) and then delegates to it.
    """
    cls = type(self)

    if cls._is_protocol:
        raise TypeError('Protocols cannot be instantiated')

    # Already using a custom `__init__`. No need to calculate correct
    # `__init__` to call. This can lead to RecursionError. See bpo-45121.
    if cls.__init__ is not _no_init_or_replace_init:
        return

    # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`.
    # The first instantiation of the subclass will call `_no_init_or_replace_init` which
    # searches for a proper new `__init__` in the MRO. The new `__init__`
    # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent
    # instantiation of the protocol subclass will thus use the new
    # `__init__` and no longer call `_no_init_or_replace_init`.
    for base in cls.__mro__:
        init = base.__dict__.get('__init__', _no_init_or_replace_init)
        if init is not _no_init_or_replace_init:
            cls.__init__ = init
            break
    else:
        # should not happen
        # (the `else` of a `for` runs only when the loop never hit `break`)
        cls.__init__ = object.__init__

    # finally run the (now installed) real initializer for this instance
    cls.__init__(self, *args, **kwargs)
Usage docs: https://docs.pydantic.dev/2.9/concepts/models/
A base class for creating Pydantic models.
Attributes:
- __class_vars__: The names of the class variables defined on the model.
- __private_attributes__: Metadata about the private attributes of the model.
- __signature__: The synthesized `__init__` [`Signature`][inspect.Signature] of the model.
- __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
- __pydantic_core_schema__: The core schema of the model.
- __pydantic_custom_init__: Whether the model has a custom `__init__` function.
- __pydantic_decorators__: Metadata containing the decorators defined on the model. This replaces `Model.__validators__` and `Model.__root_validators__` from Pydantic V1.
- __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to `__args__`, `__origin__`, `__parameters__` in typing-module generics. May eventually be replaced by these.
- __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
- __pydantic_post_init__: The name of the post-init method for the model, if defined.
- __pydantic_root_model__: Whether the model is a [`RootModel`][pydantic.root_model.RootModel].
- __pydantic_serializer__: The `pydantic-core` `SchemaSerializer` used to dump instances of the model.
- __pydantic_validator__: The `pydantic-core` `SchemaValidator` used to validate instances of the model.
- __pydantic_extra__: A dictionary containing extra values, if [`extra`][pydantic.config.ConfigDict.extra] is set to `'allow'`.
- __pydantic_fields_set__: The names of fields explicitly set during instantiation.
- __pydantic_private__: Values of private attributes set on the model instance.
class SampleMeasureBase(MeasureBase, ABC, frozen=True):
    """Abstract base for measures computed from one sample at a time."""

    scope: Literal["sample"] = "sample"

    @abstractmethod
    def compute(self, sample: SampleLike) -> MeasureValue:
        """Compute this measure for `sample` (implemented by subclasses)."""
Usage docs: https://docs.pydantic.dev/2.9/concepts/models/
A base class for creating Pydantic models.
Attributes:
- __class_vars__: The names of the class variables defined on the model.
- __private_attributes__: Metadata about the private attributes of the model.
- __signature__: The synthesized `__init__` [`Signature`][inspect.Signature] of the model.
- __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
- __pydantic_core_schema__: The core schema of the model.
- __pydantic_custom_init__: Whether the model has a custom `__init__` function.
- __pydantic_decorators__: Metadata containing the decorators defined on the model. This replaces `Model.__validators__` and `Model.__root_validators__` from Pydantic V1.
- __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to `__args__`, `__origin__`, `__parameters__` in typing-module generics. May eventually be replaced by these.
- __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
- __pydantic_post_init__: The name of the post-init method for the model, if defined.
- __pydantic_root_model__: Whether the model is a [`RootModel`][pydantic.root_model.RootModel].
- __pydantic_serializer__: The `pydantic-core` `SchemaSerializer` used to dump instances of the model.
- __pydantic_validator__: The `pydantic-core` `SchemaValidator` used to validate instances of the model.
- __pydantic_extra__: A dictionary containing extra values, if [`extra`][pydantic.config.ConfigDict.extra] is set to `'allow'`.
- __pydantic_fields_set__: The names of fields explicitly set during instantiation.
- __pydantic_private__: Values of private attributes set on the model instance.
@abstractmethod
def compute(self, sample: SampleLike) -> MeasureValue:
    """Compute this measure for `sample`; concrete measures override this."""
compute the measure
Inherited Members
Usage docs: https://docs.pydantic.dev/2.9/concepts/models/
A base class for creating Pydantic models.
Attributes:
- __class_vars__: The names of the class variables defined on the model.
- __private_attributes__: Metadata about the private attributes of the model.
- __signature__: The synthesized `__init__` [`Signature`][inspect.Signature] of the model.
- __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
- __pydantic_core_schema__: The core schema of the model.
- __pydantic_custom_init__: Whether the model has a custom `__init__` function.
- __pydantic_decorators__: Metadata containing the decorators defined on the model. This replaces `Model.__validators__` and `Model.__root_validators__` from Pydantic V1.
- __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to `__args__`, `__origin__`, `__parameters__` in typing-module generics. May eventually be replaced by these.
- __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
- __pydantic_post_init__: The name of the post-init method for the model, if defined.
- __pydantic_root_model__: Whether the model is a [`RootModel`][pydantic.root_model.RootModel].
- __pydantic_serializer__: The `pydantic-core` `SchemaSerializer` used to dump instances of the model.
- __pydantic_validator__: The `pydantic-core` `SchemaValidator` used to validate instances of the model.
- __pydantic_extra__: A dictionary containing extra values, if [`extra`][pydantic.config.ConfigDict.extra] is set to `'allow'`.
- __pydantic_fields_set__: The names of fields explicitly set during instantiation.
- __pydantic_private__: Values of private attributes set on the model instance.
Inherited Members
class SampleMean(_Mean, SampleMeasureBase, frozen=True):
    """The mean value of a single tensor, reduced over `axes`"""

    def compute(self, sample: SampleLike) -> MeasureValue:
        tensor = sample.members[self.member_id]
        return tensor.mean(dim=self.axes)

    def model_post_init(self, __context: Any):
        # explicit raise instead of `assert`, which is stripped under `python -O`
        if self.axes is not None and AxisId("batch") in self.axes:
            raise ValueError(
                f"sample measures must not reduce over the 'batch' axis, got axes={self.axes}"
            )
The mean value of a single tensor
def compute(self, sample: SampleLike) -> MeasureValue:
    """Return the mean of this measure's member tensor over `self.axes`."""
    return sample.members[self.member_id].mean(dim=self.axes)
compute the measure
def model_post_init(self, __context: Any):
    """Reject reduction `axes` that include the ``batch`` axis.

    Raises:
        ValueError: if ``batch`` is among `self.axes`. An explicit raise is used
            instead of `assert`, which is stripped under ``python -O``.
    """
    if self.axes is not None and AxisId("batch") in self.axes:
        raise ValueError(
            f"sample measures must not reduce over the 'batch' axis, got axes={self.axes}"
        )
Override this method to perform additional initialization after `__init__` and `model_construct`.
This is useful if you want to do some validation that requires the entire model to be initialized.
Inherited Members
class DatasetMean(_Mean, DatasetMeasureBase, frozen=True):
    """The mean value across multiple samples"""

    def model_post_init(self, __context: Any):
        # explicit raise instead of `assert`, which is stripped under `python -O`
        if self.axes is not None and AxisId("batch") not in self.axes:
            raise ValueError(
                f"dataset measures must reduce over the 'batch' axis, got axes={self.axes}"
            )
The mean value across multiple samples
def model_post_init(self, __context: Any):
    """Require explicit reduction `axes` to include the ``batch`` axis.

    Raises:
        ValueError: if `self.axes` is given but lacks ``batch``. An explicit
            raise is used instead of `assert`, which is stripped under ``-O``.
    """
    if self.axes is not None and AxisId("batch") not in self.axes:
        raise ValueError(
            f"dataset measures must reduce over the 'batch' axis, got axes={self.axes}"
        )
Override this method to perform additional initialization after `__init__` and `model_construct`.
This is useful if you want to do some validation that requires the entire model to be initialized.
Inherited Members
class SampleStd(_Std, SampleMeasureBase, frozen=True):
    """The standard deviation of a single tensor, reduced over `axes`"""

    def compute(self, sample: SampleLike) -> MeasureValue:
        tensor = sample.members[self.member_id]
        return tensor.std(dim=self.axes)

    def model_post_init(self, __context: Any):
        # explicit raise instead of `assert`, which is stripped under `python -O`
        if self.axes is not None and AxisId("batch") in self.axes:
            raise ValueError(
                f"sample measures must not reduce over the 'batch' axis, got axes={self.axes}"
            )
The standard deviation of a single tensor
def compute(self, sample: SampleLike) -> MeasureValue:
    """Return the standard deviation of the member tensor over `self.axes`."""
    return sample.members[self.member_id].std(dim=self.axes)
compute the measure
def model_post_init(self, __context: Any):
    """Reject reduction `axes` that include the ``batch`` axis.

    Raises:
        ValueError: if ``batch`` is among `self.axes`. An explicit raise is used
            instead of `assert`, which is stripped under ``python -O``.
    """
    if self.axes is not None and AxisId("batch") in self.axes:
        raise ValueError(
            f"sample measures must not reduce over the 'batch' axis, got axes={self.axes}"
        )
Override this method to perform additional initialization after `__init__` and `model_construct`.
This is useful if you want to do some validation that requires the entire model to be initialized.
Inherited Members
class DatasetStd(_Std, DatasetMeasureBase, frozen=True):
    """The standard deviation across multiple samples"""

    def model_post_init(self, __context: Any):
        # explicit raise instead of `assert`, which is stripped under `python -O`
        if self.axes is not None and AxisId("batch") not in self.axes:
            raise ValueError(
                f"dataset measures must reduce over the 'batch' axis, got axes={self.axes}"
            )
The standard deviation across multiple samples
def model_post_init(self, __context: Any):
    """Require explicit reduction `axes` to include the ``batch`` axis.

    Raises:
        ValueError: if `self.axes` is given but lacks ``batch``. An explicit
            raise is used instead of `assert`, which is stripped under ``-O``.
    """
    if self.axes is not None and AxisId("batch") not in self.axes:
        raise ValueError(
            f"dataset measures must reduce over the 'batch' axis, got axes={self.axes}"
        )
Override this method to perform additional initialization after `__init__` and `model_construct`.
This is useful if you want to do some validation that requires the entire model to be initialized.
Inherited Members
class SampleVar(_Var, SampleMeasureBase, frozen=True):
    """The variance of a single tensor, reduced over `axes`"""

    def compute(self, sample: SampleLike) -> MeasureValue:
        tensor = sample.members[self.member_id]
        return tensor.var(dim=self.axes)

    def model_post_init(self, __context: Any):
        # explicit raise instead of `assert`, which is stripped under `python -O`
        if self.axes is not None and AxisId("batch") in self.axes:
            raise ValueError(
                f"sample measures must not reduce over the 'batch' axis, got axes={self.axes}"
            )
The variance of a single tensor
def compute(self, sample: SampleLike) -> MeasureValue:
    """Return the variance of the member tensor over `self.axes`."""
    return sample.members[self.member_id].var(dim=self.axes)
compute the measure
def model_post_init(self, __context: Any):
    """Reject reduction `axes` that include the ``batch`` axis.

    Raises:
        ValueError: if ``batch`` is among `self.axes`. An explicit raise is used
            instead of `assert`, which is stripped under ``python -O``.
    """
    if self.axes is not None and AxisId("batch") in self.axes:
        raise ValueError(
            f"sample measures must not reduce over the 'batch' axis, got axes={self.axes}"
        )
Override this method to perform additional initialization after `__init__` and `model_construct`.
This is useful if you want to do some validation that requires the entire model to be initialized.
Inherited Members
class DatasetVar(_Var, DatasetMeasureBase, frozen=True):
    """The variance across multiple samples"""

    def model_post_init(self, __context: Any):  # TODO: turn into @model_validator
        # explicit raise instead of `assert`, which is stripped under `python -O`
        if self.axes is not None and AxisId("batch") not in self.axes:
            raise ValueError(
                f"dataset measures must reduce over the 'batch' axis, got axes={self.axes}"
            )
The variance across multiple samples
def model_post_init(self, __context: Any):  # TODO: turn into @model_validator
    """Require explicit reduction `axes` to include the ``batch`` axis.

    Raises:
        ValueError: if `self.axes` is given but lacks ``batch``. An explicit
            raise is used instead of `assert`, which is stripped under ``-O``.
    """
    if self.axes is not None and AxisId("batch") not in self.axes:
        raise ValueError(
            f"dataset measures must reduce over the 'batch' axis, got axes={self.axes}"
        )
Override this method to perform additional initialization after `__init__` and `model_construct`.
This is useful if you want to do some validation that requires the entire model to be initialized.
Inherited Members
class SampleQuantile(_Quantile, SampleMeasureBase, frozen=True):
    """The `q`-quantile of a single tensor (`q` is a fraction in [0, 1])"""

    def compute(self, sample: SampleLike) -> MeasureValue:
        tensor = sample.members[self.member_id]
        return tensor.quantile(self.q, dim=self.axes)

    def model_post_init(self, __context: Any):
        super().model_post_init(__context)  # lets `_Quantile` validate `q`
        # explicit raise instead of `assert`, which is stripped under `python -O`
        if self.axes is not None and AxisId("batch") in self.axes:
            raise ValueError(
                f"sample measures must not reduce over the 'batch' axis, got axes={self.axes}"
            )
The `q`-th quantile of a single tensor (`q` is a fraction in [0, 1]).
def compute(self, sample: SampleLike) -> MeasureValue:
    """Return the `self.q` quantile of the member tensor over `self.axes`."""
    return sample.members[self.member_id].quantile(self.q, dim=self.axes)
compute the measure
def model_post_init(self, __context: Any):
    """Validate `q` (via the parent class), then reject a ``batch`` reduction.

    Raises:
        ValueError: if ``batch`` is among `self.axes`. An explicit raise is used
            instead of `assert`, which is stripped under ``python -O``.
    """
    super().model_post_init(__context)
    if self.axes is not None and AxisId("batch") in self.axes:
        raise ValueError(
            f"sample measures must not reduce over the 'batch' axis, got axes={self.axes}"
        )
Override this method to perform additional initialization after `__init__` and `model_construct`.
This is useful if you want to do some validation that requires the entire model to be initialized.
Inherited Members
class DatasetPercentile(_Quantile, DatasetMeasureBase, frozen=True):
    """The `q`-quantile across multiple samples

    Note: despite the class name, `q` is a quantile in [0, 1], not a percentile.
    """

    def model_post_init(self, __context: Any):
        super().model_post_init(__context)  # lets `_Quantile` validate `q`
        # explicit raise instead of `assert`, which is stripped under `python -O`
        if self.axes is not None and AxisId("batch") not in self.axes:
            raise ValueError(
                f"dataset measures must reduce over the 'batch' axis, got axes={self.axes}"
            )
The `q`-th quantile across multiple samples (`q` is a fraction in [0, 1]).
def model_post_init(self, __context: Any):
    """Validate `q` (via the parent class), then require a ``batch`` reduction.

    Raises:
        ValueError: if `self.axes` is given but lacks ``batch``. An explicit
            raise is used instead of `assert`, which is stripped under ``-O``.
    """
    super().model_post_init(__context)
    if self.axes is not None and AxisId("batch") not in self.axes:
        raise ValueError(
            f"dataset measures must reduce over the 'batch' axis, got axes={self.axes}"
        )
Override this method to perform additional initialization after `__init__` and `model_construct`.
This is useful if you want to do some validation that requires the entire model to be initialized.