Coverage for src / bioimageio / core / proc_setup.py: 95%

83 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-22 18:38 +0000

1from itertools import chain 

2from typing import ( 

3 Callable, 

4 Iterable, 

5 List, 

6 Mapping, 

7 NamedTuple, 

8 Optional, 

9 Sequence, 

10 Set, 

11 Union, 

12) 

13 

14from typing_extensions import assert_never 

15 

16from bioimageio.core.digest_spec import get_member_id 

17from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 

18 

# Compatibility guard: `CustomPostprocessingDescr` was added after spec 0.5.9.1.
# On older installs the attribute does not exist and this name is bound to `None`;
# code using it must check for `None` first.
_CustomPostprocessingDescr = getattr(
    v0_5,
    "CustomPostprocessingDescr",
    None,
)

21 

22from .proc_ops import ( 

23 AddKnownDatasetStats, 

24 CustomPostprocessing, 

25 EnsureDtype, 

26 Processing, 

27 UpdateStats, 

28 get_proc, 

29) 

30from .sample import Sample 

31from .stat_calculators import StatsCalculator 

32from .stat_measures import ( 

33 DatasetMeasure, 

34 DatasetMeasureBase, 

35 Measure, 

36 MeasureValue, 

37 SampleMeasure, 

38 SampleMeasureBase, 

39) 

40 

# Any tensor description handled in this module: input or output tensors of
# model format 0.4 or 0.5.
TensorDescr = Union[
    v0_4.InputTensorDescr,
    v0_4.OutputTensorDescr,
    v0_5.InputTensorDescr,
    v0_5.OutputTensorDescr,
]

47 

48 

class PreAndPostprocessing(NamedTuple):
    """Processing operators of a model, as returned by `setup_pre_and_postprocessing`."""

    # operators to run before the model (built from the model's input descriptions)
    pre: List[Processing]
    # operators to run after the model (built from the model's output descriptions)
    post: List[Processing]

52 

53 

class _ProcessingCallables(NamedTuple):
    """Pair of callables returned by `get_pre_and_postprocessing`."""

    # applies the preprocessing operators to a sample; returns nothing
    pre: Callable[[Sample], None]
    # applies the postprocessing operators to a sample; returns nothing
    post: Callable[[Sample], None]

57 

58 

class _ApplyProcs:
    """Callable that runs a fixed sequence of processing operators on a sample."""

    def __init__(self, procs: Sequence[Processing]):
        """Store `procs`; the sequence is kept by reference, not copied."""
        super().__init__()
        self._procs = procs

    def __call__(self, sample: Sample) -> None:
        """Call every stored operator, in order, with `sample`.

        The operators are called for their effect on `sample`; nothing is returned.
        """
        for apply_proc in self._procs:
            apply_proc(sample)

67 

68 

def get_pre_and_postprocessing(
    model: AnyModelDescr,
    *,
    dataset_for_initial_statistics: Iterable[Sample],
    keep_updating_initial_dataset_stats: bool = False,
    fixed_dataset_stats: Optional[Mapping[DatasetMeasure, MeasureValue]] = None,
) -> _ProcessingCallables:
    """Creates callables to apply pre- and postprocessing in-place to a sample

    All arguments are forwarded unchanged to `setup_pre_and_postprocessing`;
    the resulting operator lists are each wrapped in a single callable.
    """
    pre_ops, post_ops = setup_pre_and_postprocessing(
        model=model,
        dataset_for_initial_statistics=dataset_for_initial_statistics,
        keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats,
        fixed_dataset_stats=fixed_dataset_stats,
    )
    apply_pre = _ApplyProcs(pre_ops)
    apply_post = _ApplyProcs(post_ops)
    return _ProcessingCallables(pre=apply_pre, post=apply_post)

85 

86 

def setup_pre_and_postprocessing(
    model: AnyModelDescr,
    dataset_for_initial_statistics: Iterable[Sample],
    keep_updating_initial_dataset_stats: bool = False,
    fixed_dataset_stats: Optional[Mapping[DatasetMeasure, MeasureValue]] = None,
) -> PreAndPostprocessing:
    """Get pre- and postprocessing operators for a `model` description.

    Used in `bioimageio.core.create_prediction_pipeline`.

    Args:
        model: Model description whose `inputs`/`outputs` describe the processing.
        dataset_for_initial_statistics: Samples used to compute every required
            measure that is not a key of `fixed_dataset_stats`. Only iterated if
            at least one such measure exists.
        keep_updating_initial_dataset_stats: Forwarded to `UpdateStats`.
        fixed_dataset_stats: Known measure values. If non-empty, an
            `AddKnownDatasetStats` operator is placed first in the preprocessing.

    Returns:
        The operator lists; the preprocessing list is ordered as
        `[AddKnownDatasetStats?, UpdateStats?, *described preprocessing]`.
    """
    pre_ops = _get_described_procs(model.inputs)
    post_ops = _get_described_procs(model.outputs)

    # every measure any described operator (pre or post) asks for
    required = set(
        chain.from_iterable(op.required_measures for op in chain(pre_ops, post_ops))
    )
    known = {} if fixed_dataset_stats is None else fixed_dataset_stats
    missing = {m for m in required if m not in known}

    # statistics operators that have to run ahead of the described preprocessing
    stat_ops: List[Processing] = []
    if missing:
        calculator = StatsCalculator(missing)
        for sample in dataset_for_initial_statistics:
            calculator.update(sample)

        stat_ops.append(
            UpdateStats(
                StatsCalculator(required, calculator.finalize()),
                keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats,
            )
        )

    if fixed_dataset_stats:
        # known stats go in front of the stats update
        stat_ops.insert(0, AddKnownDatasetStats(fixed_dataset_stats))

    pre_ops[:0] = stat_ops
    return PreAndPostprocessing(pre_ops, post_ops)

124 

125 

class RequiredMeasures(NamedTuple):
    """Measures required by a model's described processing (see `get_requried_measures`)."""

    # measures required by the operators described on the model inputs
    pre: Set[Measure]
    # measures required by the operators described on the model outputs
    post: Set[Measure]

129 

130 

class RequiredDatasetMeasures(NamedTuple):
    """Dataset measures required by a model (see `get_required_dataset_measures`)."""

    # dataset measures required by the operators described on the model inputs
    pre: Set[DatasetMeasure]
    # dataset measures required by the operators described on the model outputs
    post: Set[DatasetMeasure]

134 

135 

class RequiredSampleMeasures(NamedTuple):
    """Sample measures required by a model (see `get_requried_sample_measures`)."""

    # sample measures required by the operators described on the model inputs
    pre: Set[SampleMeasure]
    # sample measures required by the operators described on the model outputs
    post: Set[SampleMeasure]

139 

140 

def get_requried_measures(model: AnyModelDescr) -> RequiredMeasures:
    """Collect the measures required by the processing described in `model`.

    The operators described on `model.inputs` and `model.outputs` are
    instantiated and their `required_measures` are gathered into one set each.
    """
    input_procs = _get_described_procs(model.inputs)
    output_procs = _get_described_procs(model.outputs)
    return RequiredMeasures(
        pre=set(chain.from_iterable(p.required_measures for p in input_procs)),
        post=set(chain.from_iterable(p.required_measures for p in output_procs)),
    )

148 

149 

def get_required_dataset_measures(model: AnyModelDescr) -> RequiredDatasetMeasures:
    """Return the required measures of `model` that are dataset measures.

    Filters the result of `get_requried_measures` by `DatasetMeasureBase`.
    """

    def dataset_only(measures: Set[Measure]) -> Set[DatasetMeasure]:
        return {m for m in measures if isinstance(m, DatasetMeasureBase)}

    pre_measures, post_measures = get_requried_measures(model)
    return RequiredDatasetMeasures(
        pre=dataset_only(pre_measures),
        post=dataset_only(post_measures),
    )

156 

157 

def get_requried_sample_measures(model: AnyModelDescr) -> RequiredSampleMeasures:
    """Return the required measures of `model` that are sample measures.

    Filters the result of `get_requried_measures` by `SampleMeasureBase`.
    """

    def sample_only(measures: Set[Measure]) -> Set[SampleMeasure]:
        return {m for m in measures if isinstance(m, SampleMeasureBase)}

    pre_measures, post_measures = get_requried_measures(model)
    return RequiredSampleMeasures(
        pre=sample_only(pre_measures),
        post=sample_only(post_measures),
    )

164 

165 

def _get_described_procs(
    tensor_descrs: Iterable[TensorDescr],
) -> List[Processing]:
    """Instantiate the processing operators described by `tensor_descrs`.

    For input tensor descriptions the `preprocessing` entries are used, for
    output tensor descriptions the `postprocessing` entries. Custom
    postprocessing of 0.5 outputs (if supported by the installed spec) is built
    via `CustomPostprocessing.from_proc_descr`; everything else via `get_proc`.

    Tensors of model format 0.4 (inputs and outputs alike) additionally get
    dtype handling:

    - with described processing steps:
      `EnsureDtype(declared data_type)`, the described steps, `EnsureDtype("float32")`
    - without described processing steps: only `EnsureDtype("float32")`

    Args:
        tensor_descrs: Tensor descriptions; consumed once (materialized into a
            list internally).

    Returns:
        All operators in the order of `tensor_descrs`.
    """
    tensor_descrs = list(tensor_descrs)
    # handed to custom postprocessing below
    all_output_ids = [
        get_member_id(d)
        for d in tensor_descrs
        if isinstance(d, (v0_4.OutputTensorDescr, v0_5.OutputTensorDescr))
    ]

    procs: List[Processing] = []
    for t_descr in tensor_descrs:
        # Operators are collected per tensor so that the
        # "no other processing steps" check below looks at this tensor only,
        # not at operators accumulated for previously handled tensors.
        t_procs: List[Processing] = []
        if isinstance(t_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
            member_id = get_member_id(t_descr)
            t_procs.append(
                EnsureDtype(input=member_id, output=member_id, dtype=t_descr.data_type)
            )

        if isinstance(t_descr, (v0_4.InputTensorDescr, v0_5.InputTensorDescr)):
            for proc_d in t_descr.preprocessing:
                t_procs.append(get_proc(proc_d, t_descr))
        elif isinstance(t_descr, (v0_4.OutputTensorDescr, v0_5.OutputTensorDescr)):
            for proc_d in t_descr.postprocessing:
                if (
                    _CustomPostprocessingDescr is not None
                    and isinstance(proc_d, _CustomPostprocessingDescr)
                    and isinstance(t_descr, v0_5.OutputTensorDescr)
                ):
                    t_procs.append(
                        CustomPostprocessing.from_proc_descr(
                            proc_d, t_descr, all_output_ids
                        )
                    )
                else:
                    t_procs.append(get_proc(proc_d, t_descr))
        else:
            assert_never(t_descr)

        if isinstance(t_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
            if len(t_procs) == 1:
                # remove initial ensure_dtype if there are no other processing steps
                # (the single entry is the EnsureDtype appended above)
                t_procs.clear()

            # ensure 0.4 models get float32 input
            # which has been the implicit assumption for 0.4
            member_id = get_member_id(t_descr)
            t_procs.append(
                EnsureDtype(input=member_id, output=member_id, dtype="float32")
            )

        procs.extend(t_procs)

    return procs

219 

220 return procs