Coverage for bioimageio/core/stat_measures.py: 95%
98 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
1from __future__ import annotations
3from abc import ABC, abstractmethod
4from typing import (
5 Any,
6 Dict,
7 Literal,
8 Mapping,
9 Optional,
10 Protocol,
11 Tuple,
12 TypeVar,
13 Union,
14)
16import numpy as np
17from pydantic import (
18 BaseModel,
19 BeforeValidator,
20 Discriminator,
21 PlainSerializer,
22)
23from typing_extensions import Annotated
25from .axis import AxisId
26from .common import MemberId, PerMember
27from .tensor import Tensor
def tensor_custom_before_validator(data: Union[Tensor, Mapping[str, Any]]):
    """Turn serialized tensor data into a `Tensor` before pydantic validation.

    Used as the `BeforeValidator` of the tensor branch of `MeasureValue`; it
    accepts the `{"data": ..., "dims": ...}` layout that
    `tensor_custom_serializer` produces.

    Args:
        data: Either a `Tensor` (returned unchanged) or a mapping with a
            "data" entry (anything `np.asarray` accepts) and a "dims" entry.

    Returns:
        The given `Tensor`, or a new one built from `data["data"]` and
        `data["dims"]`.

    Raises:
        ValueError: If `data` is not a `Tensor` and cannot be indexed with
            "data" and "dims". A ValueError (instead of the raw
            KeyError/TypeError/IndexError) lets pydantic report it as a
            validation error of the field.
    """
    if isinstance(data, Tensor):
        return data

    # custom before validation logic
    try:
        array_like = data["data"]
        dims = data["dims"]
    except (KeyError, IndexError, TypeError) as e:
        raise ValueError(
            "expected a Tensor or a mapping with 'data' and 'dims' entries,"
            f" got {type(data).__name__}"
        ) from e

    return Tensor(np.asarray(array_like), dims=dims)
def tensor_custom_serializer(t: Tensor) -> Dict[str, Any]:
    """Serialize a `Tensor` into plain, JSON-compatible Python data.

    The result has the layout `{"data": <t.data.data.tolist()>, "dims": [str, ...]}`,
    which is what `tensor_custom_before_validator` accepts back.
    """
    # `t.data.data` must provide `.tolist()` (presumably the numpy array
    # backing the tensor — TODO confirm)
    nested_values = t.data.data.tolist()
    dim_names = [str(dim) for dim in t.dims]
    return {"data": nested_values, "dims": dim_names}
# The value of a computed measure: either a plain float or a `Tensor`.
# The tensor branch is validated/serialized by the custom functions above, i.e.
# (de)serialized as {"data": <t.data.data.tolist()>, "dims": [<dim names as str>]}.
# NOTE(review): member order matters for pydantic union validation; keep `float` first.
MeasureValue = Union[
    float,
    Annotated[
        Tensor,
        BeforeValidator(tensor_custom_before_validator),
        PlainSerializer(tensor_custom_serializer),
    ],
]
class SampleLike(Protocol):
    """Structural stand-in for a sample: anything with a `members` property.

    A Protocol is used here instead of the concrete sample class only to
    avoid circular imports.
    """

    @property
    def members(self) -> PerMember[Tensor]:
        """The sample's tensors, indexed by member id."""
class MeasureBase(BaseModel, frozen=True):
    """Common base of all measures: names the sample member the measure is about."""

    # id of the tensor in `SampleLike.members` this measure refers to
    member_id: MemberId
class SampleMeasureBase(MeasureBase, ABC, frozen=True):
    """Abstract base of measures computed from a single sample."""

    # discriminator value used by the `Measure` union
    scope: Literal["sample"] = "sample"

    @abstractmethod
    def compute(self, sample: SampleLike) -> MeasureValue:
        """Compute the measure for the given `sample`."""
        ...
class DatasetMeasureBase(MeasureBase, ABC, frozen=True):
    """Base of measures aggregated across multiple samples.

    Unlike `SampleMeasureBase` it declares no `compute` method here
    (presumably dataset measures are computed by separate code — TODO confirm).
    """

    # discriminator value used by the `Measure` union
    scope: Literal["dataset"] = "dataset"
class _Mean(BaseModel, frozen=True):
    """Fields shared by `SampleMean` and `DatasetMean`."""

    # discriminator value used by the per-scope measure unions
    name: Literal["mean"] = "mean"
    # NOTE(review): presumably `None` means "reduce over all axes" — confirm with Tensor.mean
    axes: Optional[Tuple[AxisId, ...]] = None
    """`axes` to reduce"""
class SampleMean(_Mean, SampleMeasureBase, frozen=True):
    """The mean value of a single tensor"""

    def compute(self, sample: SampleLike) -> MeasureValue:
        """Mean of the sample member `member_id`, reduced over `axes`."""
        tensor = sample.members[self.member_id]
        return tensor.mean(dim=self.axes)

    def model_post_init(self, __context: Any):
        """Reject `axes` that include the batch axis.

        Raises:
            ValueError: if `axes` contains `AxisId("batch")`.
        """
        # explicit raise instead of `assert`, so the check also runs under `python -O`
        if self.axes is not None and AxisId("batch") in self.axes:
            raise ValueError(
                f"{type(self).__name__} must not reduce over the 'batch' axis,"
                f" got axes={self.axes}"
            )
class DatasetMean(_Mean, DatasetMeasureBase, frozen=True):
    """The mean value across multiple samples"""

    def model_post_init(self, __context: Any):
        """Require `axes` to be `None` or to include the batch axis.

        Raises:
            ValueError: if `axes` is given but does not contain `AxisId("batch")`.
        """
        # explicit raise instead of `assert`, so the check also runs under `python -O`
        if self.axes is not None and AxisId("batch") not in self.axes:
            raise ValueError(
                f"{type(self).__name__} must reduce over the 'batch' axis"
                f" (or use axes=None), got axes={self.axes}"
            )
class _Std(BaseModel, frozen=True):
    """Fields shared by `SampleStd` and `DatasetStd`."""

    # discriminator value used by the per-scope measure unions
    name: Literal["std"] = "std"
    # NOTE(review): presumably `None` means "reduce over all axes" — confirm with Tensor.std
    axes: Optional[Tuple[AxisId, ...]] = None
    """`axes` to reduce"""
class SampleStd(_Std, SampleMeasureBase, frozen=True):
    """The standard deviation of a single tensor"""

    def compute(self, sample: SampleLike) -> MeasureValue:
        """Standard deviation of the sample member `member_id`, reduced over `axes`."""
        tensor = sample.members[self.member_id]
        return tensor.std(dim=self.axes)

    def model_post_init(self, __context: Any):
        """Reject `axes` that include the batch axis.

        Raises:
            ValueError: if `axes` contains `AxisId("batch")`.
        """
        # explicit raise instead of `assert`, so the check also runs under `python -O`
        if self.axes is not None and AxisId("batch") in self.axes:
            raise ValueError(
                f"{type(self).__name__} must not reduce over the 'batch' axis,"
                f" got axes={self.axes}"
            )
class DatasetStd(_Std, DatasetMeasureBase, frozen=True):
    """The standard deviation across multiple samples"""

    def model_post_init(self, __context: Any):
        """Require `axes` to be `None` or to include the batch axis.

        Raises:
            ValueError: if `axes` is given but does not contain `AxisId("batch")`.
        """
        # explicit raise instead of `assert`, so the check also runs under `python -O`
        if self.axes is not None and AxisId("batch") not in self.axes:
            raise ValueError(
                f"{type(self).__name__} must reduce over the 'batch' axis"
                f" (or use axes=None), got axes={self.axes}"
            )
class _Var(BaseModel, frozen=True):
    """Fields shared by `SampleVar` and `DatasetVar`."""

    # discriminator value used by the per-scope measure unions
    name: Literal["var"] = "var"
    # NOTE(review): presumably `None` means "reduce over all axes" — confirm with Tensor.var
    axes: Optional[Tuple[AxisId, ...]] = None
    """`axes` to reduce"""
class SampleVar(_Var, SampleMeasureBase, frozen=True):
    """The variance of a single tensor"""

    def compute(self, sample: SampleLike) -> MeasureValue:
        """Variance of the sample member `member_id`, reduced over `axes`."""
        tensor = sample.members[self.member_id]
        return tensor.var(dim=self.axes)

    def model_post_init(self, __context: Any):
        """Reject `axes` that include the batch axis.

        Raises:
            ValueError: if `axes` contains `AxisId("batch")`.
        """
        # explicit raise instead of `assert`, so the check also runs under `python -O`
        if self.axes is not None and AxisId("batch") in self.axes:
            raise ValueError(
                f"{type(self).__name__} must not reduce over the 'batch' axis,"
                f" got axes={self.axes}"
            )
class DatasetVar(_Var, DatasetMeasureBase, frozen=True):
    """The variance across multiple samples"""

    def model_post_init(self, __context: Any):  # TODO: turn into @model_validator
        """Require `axes` to be `None` or to include the batch axis.

        Raises:
            ValueError: if `axes` is given but does not contain `AxisId("batch")`.
        """
        # explicit raise instead of `assert`, so the check also runs under `python -O`
        if self.axes is not None and AxisId("batch") not in self.axes:
            raise ValueError(
                f"{type(self).__name__} must reduce over the 'batch' axis"
                f" (or use axes=None), got axes={self.axes}"
            )
class _Quantile(BaseModel, frozen=True):
    """Fields and validation shared by `SampleQuantile` and `DatasetPercentile`."""

    # discriminator value used by the per-scope measure unions
    name: Literal["quantile"] = "quantile"
    # quantile level; must lie in the closed interval [0, 1]
    q: float
    axes: Optional[Tuple[AxisId, ...]] = None
    """`axes` to reduce"""

    def model_post_init(self, __context: Any):
        """Check that the quantile level `q` lies in [0, 1].

        Raises:
            ValueError: if `q` is outside [0, 1] (this includes NaN).
        """
        # explicit raise instead of `assert`, so the check also runs under `python -O`
        if not 0.0 <= self.q <= 1.0:
            raise ValueError(f"`q` must be in [0, 1], got {self.q}")
class SampleQuantile(_Quantile, SampleMeasureBase, frozen=True):
    """The `q`th quantile of a single tensor"""

    def compute(self, sample: SampleLike) -> MeasureValue:
        """`q`-quantile of the sample member `member_id`, reduced over `axes`."""
        tensor = sample.members[self.member_id]
        return tensor.quantile(self.q, dim=self.axes)

    def model_post_init(self, __context: Any):
        """Validate `q` (via `_Quantile`) and reject `axes` with the batch axis.

        Raises:
            ValueError: if `q` is outside [0, 1] or `axes` contains `AxisId("batch")`.
        """
        super().model_post_init(__context)
        # explicit raise instead of `assert`, so the check also runs under `python -O`
        if self.axes is not None and AxisId("batch") in self.axes:
            raise ValueError(
                f"{type(self).__name__} must not reduce over the 'batch' axis,"
                f" got axes={self.axes}"
            )
class DatasetPercentile(_Quantile, DatasetMeasureBase, frozen=True):
    """The `q`th quantile across multiple samples"""

    def model_post_init(self, __context: Any):
        """Validate `q` (via `_Quantile`) and require the batch axis in `axes`.

        Raises:
            ValueError: if `q` is outside [0, 1], or `axes` is given but does
                not contain `AxisId("batch")`.
        """
        super().model_post_init(__context)
        # explicit raise instead of `assert`, so the check also runs under `python -O`
        if self.axes is not None and AxisId("batch") not in self.axes:
            raise ValueError(
                f"{type(self).__name__} must reduce over the 'batch' axis"
                f" (or use axes=None), got axes={self.axes}"
            )
# Tagged union of all sample-scoped measures, discriminated by their `name` literal.
SampleMeasure = Annotated[
    Union[SampleMean, SampleStd, SampleVar, SampleQuantile], Discriminator("name")
]
# Tagged union of all dataset-scoped measures, discriminated by their `name` literal.
DatasetMeasure = Annotated[
    Union[DatasetMean, DatasetStd, DatasetVar, DatasetPercentile], Discriminator("name")
]
# Any measure: first discriminated by `scope` ("sample" / "dataset"), then by `name`.
Measure = Annotated[Union[SampleMeasure, DatasetMeasure], Discriminator("scope")]
# Mapping from a measure to its computed value. Measures are declared
# `frozen=True`, presumably so they are hashable and usable as dict keys.
Stat = Dict[Measure, MeasureValue]
# Each measure kind as the union of its sample- and dataset-scoped variant,
# paired with a TypeVar bound to that union for use in generic signatures.
MeanMeasure = Union[SampleMean, DatasetMean]
MeanMeasureT = TypeVar("MeanMeasureT", bound=MeanMeasure)

StdMeasure = Union[SampleStd, DatasetStd]
StdMeasureT = TypeVar("StdMeasureT", bound=StdMeasure)

VarMeasure = Union[SampleVar, DatasetVar]
VarMeasureT = TypeVar("VarMeasureT", bound=VarMeasure)

PercentileMeasure = Union[SampleQuantile, DatasetPercentile]
PercentileMeasureT = TypeVar("PercentileMeasureT", bound=PercentileMeasure)