Coverage for bioimageio/core/stat_measures.py: 95%

98 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-19 09:02 +0000

1from __future__ import annotations 

2 

3from abc import ABC, abstractmethod 

4from typing import ( 

5 Any, 

6 Dict, 

7 Literal, 

8 Mapping, 

9 Optional, 

10 Protocol, 

11 Tuple, 

12 TypeVar, 

13 Union, 

14) 

15 

16import numpy as np 

17from pydantic import ( 

18 BaseModel, 

19 BeforeValidator, 

20 Discriminator, 

21 PlainSerializer, 

22) 

23from typing_extensions import Annotated 

24 

25from .axis import AxisId 

26from .common import MemberId, PerMember 

27from .tensor import Tensor 

28 

29 

def tensor_custom_before_validator(data: Union[Tensor, Mapping[str, Any]]):
    """Coerce serialized tensor data into a `Tensor` before pydantic validation.

    Counterpart of `tensor_custom_serializer`: an existing `Tensor` is returned
    unchanged; a mapping must provide a ``"data"`` entry (array-like, converted
    with `np.asarray`) and a ``"dims"`` entry.

    Raises:
        ValueError: if `data` is neither a `Tensor` nor a mapping containing
            both ``"data"`` and ``"dims"``. A `ValueError` is used (instead of
            letting `KeyError`/`TypeError` escape) because pydantic only turns
            `ValueError`/`AssertionError` raised in a validator into a
            validation error, which lets the `MeasureValue` union fall back to
            its other member cleanly.
    """
    if isinstance(data, Tensor):
        return data

    if not isinstance(data, Mapping):
        raise ValueError(
            "expected a Tensor or a mapping with 'data' and 'dims', "
            f"got {type(data)}"
        )

    missing = [key for key in ("data", "dims") if key not in data]
    if missing:
        raise ValueError(f"tensor mapping is missing key(s): {missing}")

    return Tensor(np.asarray(data["data"]), dims=data["dims"])

36 

37 

def tensor_custom_serializer(t: Tensor) -> Dict[str, Any]:
    """Serialize a `Tensor` into a plain, JSON-compatible dict.

    The result has the shape ``{"data": ..., "dims": [...]}``: ``"data"`` is
    the underlying array converted with ``.tolist()`` and ``"dims"`` holds the
    dimension names as strings. This is the layout that
    `tensor_custom_before_validator` reads back.
    """
    # NOTE(review): assumes `t.data.data` is the underlying numpy array
    # (e.g. the `.data` of an xarray DataArray) -- confirm against `Tensor`.
    payload = t.data.data.tolist()
    dim_names = [str(dim) for dim in t.dims]
    return {"data": payload, "dims": dim_names}

41 

42 

# Tensor-valued measure results; (de)serialized through
# `tensor_custom_before_validator` and `tensor_custom_serializer`.
_TensorMeasureValue = Annotated[
    Tensor,
    BeforeValidator(tensor_custom_before_validator),
    PlainSerializer(tensor_custom_serializer),
]

# The value of a computed measure: either a plain float or a tensor.
MeasureValue = Union[float, _TensorMeasureValue]

51 

52 

class SampleLike(Protocol):
    """Structural stand-in for a sample: anything that exposes `members`.

    A Protocol is used here rather than the concrete sample type only to
    avoid a circular import.
    """

    @property
    def members(self) -> PerMember[Tensor]:
        """The sample's tensors, looked up by member id."""
        ...

57 

58 

class MeasureBase(BaseModel, frozen=True):
    """Common base of all measures: an immutable model bound to one sample member."""

    # id of the sample member (tensor) this measure refers to
    member_id: MemberId

61 

62 

class SampleMeasureBase(MeasureBase, ABC, frozen=True):
    """Abstract base for measures that are computed from a single sample."""

    # discriminator value separating sample measures from dataset measures
    scope: Literal["sample"] = "sample"

    @abstractmethod
    def compute(self, sample: SampleLike) -> MeasureValue:
        """Compute this measure's value for `sample`."""
        ...

70 

71 

class DatasetMeasureBase(MeasureBase, ABC, frozen=True):
    """Base for measures aggregated across multiple samples.

    Unlike `SampleMeasureBase` this declares no `compute` method;
    presumably dataset measures are accumulated elsewhere -- confirm.
    """

    # discriminator value separating dataset measures from sample measures
    scope: Literal["dataset"] = "dataset"

74 

75 

class _Mean(BaseModel, frozen=True):
    """Fields shared by `SampleMean` and `DatasetMean`."""

    # discriminator value used by the `name`-discriminated measure unions
    name: Literal["mean"] = "mean"
    axes: Optional[Tuple[AxisId, ...]] = None
    """`axes` to reduce"""

80 

81 

class SampleMean(_Mean, SampleMeasureBase, frozen=True):
    """The mean value of a single tensor"""

    def compute(self, sample: SampleLike) -> MeasureValue:
        """Mean of the `member_id` tensor of `sample`, reduced over `axes`."""
        tensor = sample.members[self.member_id]
        return tensor.mean(dim=self.axes)

    def model_post_init(self, __context: Any) -> None:
        # A per-sample measure must not reduce over the batch axis.
        # Checked explicitly (not with `assert`) so the check still runs
        # under `python -O`, where assert statements are removed.
        if self.axes is not None and AxisId("batch") in self.axes:
            raise ValueError(
                f"{type(self).__name__}: `axes` must not include 'batch' for a "
                f"per-sample measure, got {self.axes}"
            )

91 

92 

class DatasetMean(_Mean, DatasetMeasureBase, frozen=True):
    """The mean value across multiple samples"""

    def model_post_init(self, __context: Any) -> None:
        # A dataset measure reduces across samples, so explicit `axes` must
        # include the batch axis. Checked explicitly (not with `assert`) so
        # the check still runs under `python -O`.
        if self.axes is not None and AxisId("batch") not in self.axes:
            raise ValueError(
                f"{type(self).__name__}: `axes` must include 'batch' (or be None) "
                f"for a dataset measure, got {self.axes}"
            )

98 

99 

class _Std(BaseModel, frozen=True):
    """Fields shared by `SampleStd` and `DatasetStd`."""

    # discriminator value used by the `name`-discriminated measure unions
    name: Literal["std"] = "std"
    axes: Optional[Tuple[AxisId, ...]] = None
    """`axes` to reduce"""

104 

105 

class SampleStd(_Std, SampleMeasureBase, frozen=True):
    """The standard deviation of a single tensor"""

    def compute(self, sample: SampleLike) -> MeasureValue:
        """Standard deviation of the `member_id` tensor of `sample` over `axes`."""
        tensor = sample.members[self.member_id]
        return tensor.std(dim=self.axes)

    def model_post_init(self, __context: Any) -> None:
        # A per-sample measure must not reduce over the batch axis.
        # Checked explicitly (not with `assert`) so the check still runs
        # under `python -O`, where assert statements are removed.
        if self.axes is not None and AxisId("batch") in self.axes:
            raise ValueError(
                f"{type(self).__name__}: `axes` must not include 'batch' for a "
                f"per-sample measure, got {self.axes}"
            )

115 

116 

class DatasetStd(_Std, DatasetMeasureBase, frozen=True):
    """The standard deviation across multiple samples"""

    def model_post_init(self, __context: Any) -> None:
        # A dataset measure reduces across samples, so explicit `axes` must
        # include the batch axis. Checked explicitly (not with `assert`) so
        # the check still runs under `python -O`.
        if self.axes is not None and AxisId("batch") not in self.axes:
            raise ValueError(
                f"{type(self).__name__}: `axes` must include 'batch' (or be None) "
                f"for a dataset measure, got {self.axes}"
            )

122 

123 

class _Var(BaseModel, frozen=True):
    """Fields shared by `SampleVar` and `DatasetVar`."""

    # discriminator value used by the `name`-discriminated measure unions
    name: Literal["var"] = "var"
    axes: Optional[Tuple[AxisId, ...]] = None
    """`axes` to reduce"""

128 

129 

class SampleVar(_Var, SampleMeasureBase, frozen=True):
    """The variance of a single tensor"""

    def compute(self, sample: SampleLike) -> MeasureValue:
        """Variance of the `member_id` tensor of `sample`, reduced over `axes`."""
        tensor = sample.members[self.member_id]
        return tensor.var(dim=self.axes)

    def model_post_init(self, __context: Any) -> None:
        # A per-sample measure must not reduce over the batch axis.
        # Checked explicitly (not with `assert`) so the check still runs
        # under `python -O`, where assert statements are removed.
        if self.axes is not None and AxisId("batch") in self.axes:
            raise ValueError(
                f"{type(self).__name__}: `axes` must not include 'batch' for a "
                f"per-sample measure, got {self.axes}"
            )

139 

140 

class DatasetVar(_Var, DatasetMeasureBase, frozen=True):
    """The variance across multiple samples"""

    def model_post_init(self, __context: Any) -> None:  # TODO: turn into @model_validator
        # A dataset measure reduces across samples, so explicit `axes` must
        # include the batch axis. Checked explicitly (not with `assert`) so
        # the check still runs under `python -O`.
        if self.axes is not None and AxisId("batch") not in self.axes:
            raise ValueError(
                f"{type(self).__name__}: `axes` must include 'batch' (or be None) "
                f"for a dataset measure, got {self.axes}"
            )

146 

147 

class _Quantile(BaseModel, frozen=True):
    """Fields and validation shared by `SampleQuantile` and `DatasetPercentile`."""

    # discriminator value used by the `name`-discriminated measure unions
    name: Literal["quantile"] = "quantile"
    # quantile level as a fraction in [0, 1] (not a percentage)
    q: float
    axes: Optional[Tuple[AxisId, ...]] = None
    """`axes` to reduce"""

    def model_post_init(self, __context: Any) -> None:
        # Checked explicitly (not with `assert`) so the range check still runs
        # under `python -O`. The chained comparison is False for NaN, so NaN
        # is rejected as well.
        if not 0.0 <= self.q <= 1.0:
            raise ValueError(f"quantile `q` must be within [0, 1], got {self.q}")

157 

158 

class SampleQuantile(_Quantile, SampleMeasureBase, frozen=True):
    """The `q`-quantile of a single tensor (`q` in [0, 1])"""

    def compute(self, sample: SampleLike) -> MeasureValue:
        """`q`-quantile of the `member_id` tensor of `sample`, reduced over `axes`."""
        tensor = sample.members[self.member_id]
        return tensor.quantile(self.q, dim=self.axes)

    def model_post_init(self, __context: Any) -> None:
        super().model_post_init(__context)  # range check on `q`
        # A per-sample measure must not reduce over the batch axis.
        # Checked explicitly (not with `assert`) so the check still runs
        # under `python -O`, where assert statements are removed.
        if self.axes is not None and AxisId("batch") in self.axes:
            raise ValueError(
                f"{type(self).__name__}: `axes` must not include 'batch' for a "
                f"per-sample measure, got {self.axes}"
            )

169 

170 

class DatasetPercentile(_Quantile, DatasetMeasureBase, frozen=True):
    """The `q`-quantile across multiple samples (`q` in [0, 1])"""

    def model_post_init(self, __context: Any) -> None:
        super().model_post_init(__context)  # range check on `q`
        # A dataset measure reduces across samples, so explicit `axes` must
        # include the batch axis. Checked explicitly (not with `assert`) so
        # the check still runs under `python -O`.
        if self.axes is not None and AxisId("batch") not in self.axes:
            raise ValueError(
                f"{type(self).__name__}: `axes` must include 'batch' (or be None) "
                f"for a dataset measure, got {self.axes}"
            )

177 

178 

# Discriminated unions over the concrete measure models: the inner unions
# dispatch on the `name` field, the outer one on `scope`
# ("sample" vs "dataset").
SampleMeasure = Annotated[
    Union[SampleMean, SampleStd, SampleVar, SampleQuantile],
    Discriminator("name"),
]
DatasetMeasure = Annotated[
    Union[DatasetMean, DatasetStd, DatasetVar, DatasetPercentile],
    Discriminator("name"),
]
Measure = Annotated[
    Union[SampleMeasure, DatasetMeasure],
    Discriminator("scope"),
]

# Computed statistics: measure -> value (measures are frozen models, so they
# can serve as dict keys).
Stat = Dict[Measure, MeasureValue]

187 

# Per-statistic unions spanning both scopes, each paired with a TypeVar
# bound to that union.
MeanMeasure = Union[SampleMean, DatasetMean]
MeanMeasureT = TypeVar("MeanMeasureT", bound=MeanMeasure)

StdMeasure = Union[SampleStd, DatasetStd]
StdMeasureT = TypeVar("StdMeasureT", bound=StdMeasure)

VarMeasure = Union[SampleVar, DatasetVar]
VarMeasureT = TypeVar("VarMeasureT", bound=VarMeasure)

PercentileMeasure = Union[SampleQuantile, DatasetPercentile]
PercentileMeasureT = TypeVar("PercentileMeasureT", bound=PercentileMeasure)