Coverage for bioimageio/core/proc_setup.py: 87%

82 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-19 09:02 +0000

1from typing import ( 

2 Iterable, 

3 List, 

4 Mapping, 

5 NamedTuple, 

6 Optional, 

7 Sequence, 

8 Set, 

9 Union, 

10) 

11 

12from typing_extensions import assert_never 

13 

14from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 

15from bioimageio.spec.model.v0_5 import TensorId 

16 

17from .digest_spec import get_member_ids 

18from .proc_ops import ( 

19 AddKnownDatasetStats, 

20 Processing, 

21 UpdateStats, 

22 get_proc_class, 

23) 

24from .sample import Sample 

25from .stat_calculators import StatsCalculator 

26from .stat_measures import ( 

27 DatasetMeasure, 

28 DatasetMeasureBase, 

29 Measure, 

30 MeasureValue, 

31 SampleMeasure, 

32 SampleMeasureBase, 

33) 

34 

# Any tensor description (input or output) of the supported model format
# versions 0.4 and 0.5.
TensorDescr = Union[
    v0_4.InputTensorDescr, v0_4.OutputTensorDescr,  # model format 0.4
    v0_5.InputTensorDescr, v0_5.OutputTensorDescr,  # model format 0.5
]

41 

42 

class PreAndPostprocessing(NamedTuple):
    """Pair of processing operator lists returned by `setup_pre_and_postprocessing`."""

    # operators of the preprocessing list, in order
    pre: List[Processing]
    # operators of the postprocessing list, in order
    post: List[Processing]

46 

47 

class _SetupProcessing(NamedTuple):
    """Intermediate result of `_prepare_setup_pre_and_postprocessing`.

    Holds the processing operators together with the measures they require.
    """

    # operators created from the processing descriptions of the model inputs
    pre: List[Processing]
    # operators created from the processing descriptions of the model outputs
    post: List[Processing]
    # required measures whose `member_id` is a model input
    pre_measures: Set[Measure]
    # required measures whose `member_id` is a model output
    post_measures: Set[Measure]

53 

54 

def setup_pre_and_postprocessing(
    model: AnyModelDescr,
    dataset_for_initial_statistics: Iterable[Sample],
    keep_updating_initial_dataset_stats: bool = False,
    fixed_dataset_stats: Optional[Mapping[DatasetMeasure, MeasureValue]] = None,
) -> PreAndPostprocessing:
    """Get pre- and postprocessing operators for a `model` description.

    Used in `bioimageio.core.create_prediction_pipeline`.

    Args:
        model: The model description to derive the processing operators from.
        dataset_for_initial_statistics: Samples used to compute initial values
            for all required measures that are not given by
            `fixed_dataset_stats`. Only iterated if such measures exist.
        keep_updating_initial_dataset_stats: Forwarded to the `UpdateStats`
            operators placed in front of the processing operators.
        fixed_dataset_stats: Known measure values. Measures contained here are
            not computed from `dataset_for_initial_statistics`. If non-empty,
            an `AddKnownDatasetStats` operator is put first in both lists.

    Returns:
        The preprocessing and postprocessing operator lists.
    """
    setup = _prepare_setup_pre_and_postprocessing(model)
    pre_ops, post_ops = setup.pre, setup.post
    pre_measures, post_measures = setup.pre_measures, setup.post_measures

    # Every required measure without a fixed value has to be computed.
    required = pre_measures | post_measures
    if fixed_dataset_stats is None:
        measures_to_compute = required
    else:
        measures_to_compute = {
            measure for measure in required if measure not in fixed_dataset_stats
        }

    initial_stats = {}
    if measures_to_compute:
        calculator = StatsCalculator(measures_to_compute)
        for sample in dataset_for_initial_statistics:
            calculator.update(sample)

        initial_stats = calculator.finalize()

    def make_stats_updater(measures: Set[Measure]) -> UpdateStats:
        # Stats updater seeded with the initial statistics computed above.
        return UpdateStats(
            StatsCalculator(measures, initial_stats),
            keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats,
        )

    # The preprocessing list always starts with a stats updater; the
    # postprocessing list only gets one if there are output measures.
    pre_ops.insert(0, make_stats_updater(pre_measures))
    if post_measures:
        post_ops.insert(0, make_stats_updater(post_measures))

    # Known statistics are added before anything else runs.
    if fixed_dataset_stats:
        for ops in (pre_ops, post_ops):
            ops.insert(0, AddKnownDatasetStats(fixed_dataset_stats))

    return PreAndPostprocessing(pre_ops, post_ops)

101 

102 

class RequiredMeasures(NamedTuple):
    """Required measures of any kind, as returned by `get_requried_measures`."""

    # taken from the `pre_measures` of the processing setup
    pre: Set[Measure]
    # taken from the `post_measures` of the processing setup
    post: Set[Measure]

106 

107 

class RequiredDatasetMeasures(NamedTuple):
    """Required dataset measures, as returned by `get_required_dataset_measures`."""

    # dataset measures among the `pre_measures` of the processing setup
    pre: Set[DatasetMeasure]
    # dataset measures among the `post_measures` of the processing setup
    post: Set[DatasetMeasure]

111 

112 

class RequiredSampleMeasures(NamedTuple):
    """Required sample measures, as returned by `get_requried_sample_measures`."""

    # sample measures among the `pre_measures` of the processing setup
    pre: Set[SampleMeasure]
    # sample measures among the `post_measures` of the processing setup
    post: Set[SampleMeasure]

116 

117 

def get_requried_measures(model: AnyModelDescr) -> RequiredMeasures:
    """Return all measures required by the processing described in `model`.

    The measures are split into `pre` and `post` exactly as collected by
    `_prepare_setup_pre_and_postprocessing`.
    """
    # NOTE: the spelling of the function name is kept for compatibility.
    _, _, pre_measures, post_measures = _prepare_setup_pre_and_postprocessing(model)
    return RequiredMeasures(pre=pre_measures, post=post_measures)

121 

122 

def get_required_dataset_measures(model: AnyModelDescr) -> RequiredDatasetMeasures:
    """Return the required measures of `model` that are dataset measures.

    Only measures that are instances of `DatasetMeasureBase` are kept.
    """
    _, _, pre_measures, post_measures = _prepare_setup_pre_and_postprocessing(model)
    pre_dataset_measures = {
        measure for measure in pre_measures if isinstance(measure, DatasetMeasureBase)
    }
    post_dataset_measures = {
        measure for measure in post_measures if isinstance(measure, DatasetMeasureBase)
    }
    return RequiredDatasetMeasures(pre=pre_dataset_measures, post=post_dataset_measures)

129 

130 

def get_requried_sample_measures(model: AnyModelDescr) -> RequiredSampleMeasures:
    """Return the required measures of `model` that are sample measures.

    Only measures that are instances of `SampleMeasureBase` are kept.
    """
    # NOTE: the spelling of the function name is kept for compatibility.
    _, _, pre_measures, post_measures = _prepare_setup_pre_and_postprocessing(model)
    pre_sample_measures = {
        measure for measure in pre_measures if isinstance(measure, SampleMeasureBase)
    }
    post_sample_measures = {
        measure for measure in post_measures if isinstance(measure, SampleMeasureBase)
    }
    return RequiredSampleMeasures(pre=pre_sample_measures, post=post_sample_measures)

137 

138 

def _prepare_setup_pre_and_postprocessing(model: AnyModelDescr) -> _SetupProcessing:
    """Create the processing operators of `model` and collect their measures.

    For every input/output tensor description the pre-/postprocessing
    descriptions are turned into processing operators. The measures required
    by these operators are sorted into `pre_measures` (the measure refers to a
    model input) and `post_measures` (the measure refers to a model output).

    Args:
        model: The model description to set up the processing for.

    Returns:
        The operators for the model inputs (`pre`) and outputs (`post`)
        together with the collected `pre_measures` and `post_measures`.

    Raises:
        ValueError: If a required measure refers to a member id that is
            neither a model input nor a model output.
    """
    pre_measures: Set[Measure] = set()
    post_measures: Set[Measure] = set()

    input_ids = set(get_member_ids(model.inputs))
    output_ids = set(get_member_ids(model.outputs))

    def prepare_procs(tensor_descrs: Sequence[TensorDescr]):
        # Build the operators for the given tensors; fills `pre_measures` and
        # `post_measures` of the enclosing scope as a side effect.
        procs: List[Processing] = []
        for t_descr in tensor_descrs:
            if isinstance(t_descr, (v0_4.InputTensorDescr, v0_5.InputTensorDescr)):
                proc_descrs: List[
                    Union[
                        v0_4.PreprocessingDescr,
                        v0_5.PreprocessingDescr,
                        v0_4.PostprocessingDescr,
                        v0_5.PostprocessingDescr,
                    ]
                ] = list(t_descr.preprocessing)
            elif isinstance(
                t_descr,
                (v0_4.OutputTensorDescr, v0_5.OutputTensorDescr),
            ):
                proc_descrs = list(t_descr.postprocessing)
            else:
                assert_never(t_descr)

            if isinstance(t_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
                # v0_4 tensors always end with a cast to the described data
                # type; inputs that have preprocessing also start with one.
                # NOTE(review): presumably needed because v0_4 has no explicit
                # ensure-dtype processing step - confirm.
                ensure_dtype = v0_5.EnsureDtypeDescr(
                    kwargs=v0_5.EnsureDtypeKwargs(dtype=t_descr.data_type)
                )
                if isinstance(t_descr, v0_4.InputTensorDescr) and proc_descrs:
                    proc_descrs.insert(0, ensure_dtype)

                proc_descrs.append(ensure_dtype)

            # The member id is the same for all processing steps of a tensor.
            member_id = (
                TensorId(str(t_descr.name))
                if isinstance(t_descr, v0_4.TensorDescrBase)
                else t_descr.id
            )
            for proc_d in proc_descrs:
                proc_class = get_proc_class(proc_d)
                req = proc_class.from_proc_descr(
                    proc_d, member_id  # pyright: ignore[reportArgumentType]
                )
                for m in req.required_measures:
                    # Sort by the tensor the measure refers to, not by the
                    # tensor whose processing requires it.
                    if m.member_id in input_ids:
                        pre_measures.add(m)
                    elif m.member_id in output_ids:
                        post_measures.add(m)
                    else:
                        raise ValueError(
                            f"Measure {m!r} required by the processing of tensor"
                            f" '{member_id}' refers to member '{m.member_id}',"
                            " which is neither a model input nor a model output."
                        )
                procs.append(req)
        return procs

    return _SetupProcessing(
        pre=prepare_procs(model.inputs),
        post=prepare_procs(model.outputs),
        pre_measures=pre_measures,
        post_measures=post_measures,
    )