Coverage for bioimageio/core/proc_setup.py: 92%

91 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-01 09:51 +0000

1from typing import ( 

2 Callable, 

3 Iterable, 

4 List, 

5 Mapping, 

6 NamedTuple, 

7 Optional, 

8 Sequence, 

9 Set, 

10 Union, 

11) 

12 

13from typing_extensions import assert_never 

14 

15from bioimageio.core.digest_spec import get_member_id 

16from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 

17 

18from .proc_ops import ( 

19 AddKnownDatasetStats, 

20 EnsureDtype, 

21 Processing, 

22 UpdateStats, 

23 get_proc, 

24) 

25from .sample import Sample 

26from .stat_calculators import StatsCalculator 

27from .stat_measures import ( 

28 DatasetMeasure, 

29 DatasetMeasureBase, 

30 Measure, 

31 MeasureValue, 

32 SampleMeasure, 

33 SampleMeasureBase, 

34) 

35 

# Every tensor description type handled by this module: input and output
# descriptions of the 0.4 and 0.5 model description formats.
TensorDescr = Union[
    v0_4.InputTensorDescr,
    v0_4.OutputTensorDescr,
    v0_5.InputTensorDescr,
    v0_5.OutputTensorDescr,
]

42 

43 

class PreAndPostprocessing(NamedTuple):
    """Processing operators of a model, split into pre- and postprocessing."""

    # operators applied before model inference (from the input descriptions)
    pre: List[Processing]
    # operators applied after model inference (from the output descriptions)
    post: List[Processing]

47 

48 

class _ProcessingCallables(NamedTuple):
    """Pair of callables that apply processing to a `Sample`, returning `None`."""

    # applies all preprocessing operators to a sample
    pre: Callable[[Sample], None]
    # applies all postprocessing operators to a sample
    post: Callable[[Sample], None]

52 

53 

class _SetupProcessing(NamedTuple):
    """Processing operators of a model together with the measures they require."""

    # preprocessing operators
    pre: List[Processing]
    # postprocessing operators
    post: List[Processing]
    # union of `required_measures` over all preprocessing operators
    pre_measures: Set[Measure]
    # union of `required_measures` over all postprocessing operators
    post_measures: Set[Measure]

59 

60 

class _ApplyProcs:
    """Callable that applies a fixed sequence of processing operators to a sample.

    The operators are called one after another with the same `Sample` object;
    nothing is returned, so any effect happens through the operators
    themselves.
    """

    def __init__(self, procs: Sequence[Processing]):
        super().__init__()
        # kept as given (no copy), in application order
        self._procs = procs

    def __call__(self, sample: Sample) -> None:
        for apply_op in self._procs:
            apply_op(sample)

69 

70 

def get_pre_and_postprocessing(
    model: AnyModelDescr,
    *,
    dataset_for_initial_statistics: Iterable[Sample],
    keep_updating_initial_dataset_stats: bool = False,
    fixed_dataset_stats: Optional[Mapping[DatasetMeasure, MeasureValue]] = None,
) -> _ProcessingCallables:
    """Create callables that apply pre- and postprocessing in-place to a sample.

    All arguments are forwarded unchanged to `setup_pre_and_postprocessing`;
    the resulting operator lists are each wrapped in an `_ApplyProcs` callable.

    Returns:
        `_ProcessingCallables` whose `pre` / `post` members run the
        preprocessing / postprocessing operators on a given `Sample`.
    """
    pre_ops, post_ops = setup_pre_and_postprocessing(
        model=model,
        dataset_for_initial_statistics=dataset_for_initial_statistics,
        keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats,
        fixed_dataset_stats=fixed_dataset_stats,
    )
    return _ProcessingCallables(pre=_ApplyProcs(pre_ops), post=_ApplyProcs(post_ops))

87 

88 

def setup_pre_and_postprocessing(
    model: AnyModelDescr,
    dataset_for_initial_statistics: Iterable[Sample],
    keep_updating_initial_dataset_stats: bool = False,
    fixed_dataset_stats: Optional[Mapping[DatasetMeasure, MeasureValue]] = None,
) -> PreAndPostprocessing:
    """Get pre- and postprocessing operators for a `model` description.

    Used in `bioimageio.core.create_prediction_pipeline`.

    Args:
        model: model description whose input preprocessing and output
            postprocessing steps are turned into operators.
        dataset_for_initial_statistics: samples used to compute the required
            dataset measures that are not provided by `fixed_dataset_stats`.
            It is only iterated if at least one such dataset measure exists.
        keep_updating_initial_dataset_stats: forwarded to the `UpdateStats`
            operators.
        fixed_dataset_stats: known dataset measure values. If non-empty, an
            `AddKnownDatasetStats` operator carrying them is made the first
            pre- and the first postprocessing step.

    Returns:
        The preprocessing and postprocessing operator lists.
    """
    prep, post, prep_meas, post_meas = _prepare_setup_pre_and_postprocessing(model)

    # Only dataset measures can be supplied via `fixed_dataset_stats` (typed
    # `Mapping[DatasetMeasure, ...]`), so only they can be "missing". Sample
    # measures are left to the `UpdateStats` calculators below, which receive
    # the full `prep_meas` / `post_meas`. Filtering here avoids iterating the
    # whole dataset when the model only requires sample measures.
    # NOTE(review): assumes `StatsCalculator.finalize()` yields only dataset
    # measures, so dropping sample measures here leaves `initial_stats`
    # unchanged — confirm.
    missing_dataset_stats = {
        m
        for m in prep_meas | post_meas
        if isinstance(m, DatasetMeasureBase)
        and (fixed_dataset_stats is None or m not in fixed_dataset_stats)
    }
    if missing_dataset_stats:
        initial_stats_calc = StatsCalculator(missing_dataset_stats)
        for sample in dataset_for_initial_statistics:
            initial_stats_calc.update(sample)

        initial_stats = initial_stats_calc.finalize()
    else:
        initial_stats = {}

    # NOTE(review): unlike the postprocessing counterpart below, this operator
    # is inserted even when `prep_meas` is empty; kept as-is to preserve the
    # existing operator layout.
    prep.insert(
        0,
        UpdateStats(
            StatsCalculator(prep_meas, initial_stats),
            keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats,
        ),
    )
    if post_meas:
        post.insert(
            0,
            UpdateStats(
                StatsCalculator(post_meas, initial_stats),
                keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats,
            ),
        )

    # NOTE(review): `initial_stats` does not contain `fixed_dataset_stats`;
    # confirm `StatsCalculator` handles an initial mapping that covers only
    # part of its dataset measures.
    if fixed_dataset_stats:
        # inserted last at index 0, so the known stats are added before
        # `UpdateStats` runs
        prep.insert(0, AddKnownDatasetStats(fixed_dataset_stats))
        post.insert(0, AddKnownDatasetStats(fixed_dataset_stats))

    return PreAndPostprocessing(prep, post)

135 

136 

class RequiredMeasures(NamedTuple):
    """All measures required by a model's pre- and postprocessing operators."""

    # measures required by preprocessing
    pre: Set[Measure]
    # measures required by postprocessing
    post: Set[Measure]

140 

141 

class RequiredDatasetMeasures(NamedTuple):
    """Dataset measures required by a model's pre- and postprocessing operators."""

    # dataset measures required by preprocessing
    pre: Set[DatasetMeasure]
    # dataset measures required by postprocessing
    post: Set[DatasetMeasure]

145 

146 

class RequiredSampleMeasures(NamedTuple):
    """Sample measures required by a model's pre- and postprocessing operators."""

    # sample measures required by preprocessing
    pre: Set[SampleMeasure]
    # sample measures required by postprocessing
    post: Set[SampleMeasure]

150 

151 

def get_required_measures(model: AnyModelDescr) -> RequiredMeasures:
    """Return all measures required by `model`'s pre- and postprocessing.

    The sets are the unions of `required_measures` over the respective
    processing operators built for `model`.
    """
    s = _prepare_setup_pre_and_postprocessing(model)
    return RequiredMeasures(s.pre_measures, s.post_measures)


# Misspelled original name, kept as an alias so existing callers keep working.
get_requried_measures = get_required_measures

155 

156 

def get_required_dataset_measures(model: AnyModelDescr) -> RequiredDatasetMeasures:
    """Return the dataset measures required by `model`'s pre- and postprocessing.

    These are the required measures that are instances of `DatasetMeasureBase`.
    """
    setup = _prepare_setup_pre_and_postprocessing(model)

    def dataset_only(measures: Set[Measure]) -> Set[DatasetMeasure]:
        return {meas for meas in measures if isinstance(meas, DatasetMeasureBase)}

    return RequiredDatasetMeasures(
        pre=dataset_only(setup.pre_measures),
        post=dataset_only(setup.post_measures),
    )

163 

164 

def get_required_sample_measures(model: AnyModelDescr) -> RequiredSampleMeasures:
    """Return the sample measures required by `model`'s pre- and postprocessing.

    These are the required measures that are instances of `SampleMeasureBase`.
    """
    s = _prepare_setup_pre_and_postprocessing(model)
    return RequiredSampleMeasures(
        {m for m in s.pre_measures if isinstance(m, SampleMeasureBase)},
        {m for m in s.post_measures if isinstance(m, SampleMeasureBase)},
    )


# Misspelled original name, kept as an alias so existing callers keep working.
get_requried_sample_measures = get_required_sample_measures

171 

172 

def _prepare_procs(
    tensor_descrs: Union[
        Sequence[v0_4.InputTensorDescr],
        Sequence[v0_5.InputTensorDescr],
        Sequence[v0_4.OutputTensorDescr],
        Sequence[v0_5.OutputTensorDescr],
    ],
) -> List[Processing]:
    """Create the processing operators for a sequence of tensor descriptions.

    For each tensor description, in order:
    - 0.4 descriptions start with an `EnsureDtype` to the described `data_type`;
    - the description's `preprocessing` (inputs) or `postprocessing` (outputs)
      steps follow, converted with `get_proc`;
    - 0.4 descriptions end with an `EnsureDtype` to float32. If the tensor had
      no processing steps of its own, the leading `EnsureDtype` is dropped.

    Returns:
        The operators of all tensors, concatenated in the order of
        `tensor_descrs`.
    """
    procs: List[Processing] = []
    for t_descr in tensor_descrs:
        # Collect this tensor's operators separately: the "drop the lone
        # initial EnsureDtype" rule below must only look at this tensor's
        # steps, not at those accumulated for previously visited tensors.
        t_procs: List[Processing] = []
        if isinstance(t_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
            member_id = get_member_id(t_descr)
            t_procs.append(
                EnsureDtype(input=member_id, output=member_id, dtype=t_descr.data_type)
            )

        if isinstance(t_descr, (v0_4.InputTensorDescr, v0_5.InputTensorDescr)):
            for proc_d in t_descr.preprocessing:
                t_procs.append(get_proc(proc_d, t_descr))
        elif isinstance(t_descr, (v0_4.OutputTensorDescr, v0_5.OutputTensorDescr)):
            for proc_d in t_descr.postprocessing:
                t_procs.append(get_proc(proc_d, t_descr))
        else:
            assert_never(t_descr)

        # NOTE(review): this also matches 0.4 *output* descriptions (as the
        # original nested-tuple check did); confirm that is intended.
        if isinstance(t_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
            if len(t_procs) == 1:
                # remove initial ensure_dtype if there are no other processing steps
                assert isinstance(t_procs[0], EnsureDtype)
                t_procs = []

            # ensure 0.4 models get float32 input
            # which has been the implicit assumption for 0.4
            member_id = get_member_id(t_descr)
            t_procs.append(
                EnsureDtype(input=member_id, output=member_id, dtype="float32")
            )

        procs.extend(t_procs)

    return procs

215 

216 

def _prepare_setup_pre_and_postprocessing(model: AnyModelDescr) -> _SetupProcessing:
    """Build `model`'s processing operators and collect the measures they require.

    Preprocessing operators come from `model.inputs` and postprocessing
    operators from `model.outputs` (via `_prepare_procs`). Each measure set is
    the union of `required_measures` over the respective operators.
    Unsupported model description types fail via `assert_never`.
    """
    if not isinstance(model, (v0_4.ModelDescr, v0_5.ModelDescr)):
        assert_never(model)

    pre_procs = _prepare_procs(model.inputs)
    post_procs = _prepare_procs(model.outputs)

    def collect_measures(procs: List[Processing]) -> Set[Measure]:
        return {measure for op in procs for measure in op.required_measures}

    return _SetupProcessing(
        pre=pre_procs,
        post=post_procs,
        pre_measures=collect_measures(pre_procs),
        post_measures=collect_measures(post_procs),
    )