Coverage for src / bioimageio / core / proc_setup.py: 96%

78 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-27 22:06 +0000

1from itertools import chain 

2from typing import ( 

3 Callable, 

4 Iterable, 

5 List, 

6 Mapping, 

7 NamedTuple, 

8 Optional, 

9 Sequence, 

10 Set, 

11 Union, 

12) 

13 

14from typing_extensions import assert_never 

15 

16from bioimageio.core.digest_spec import get_member_id 

17from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 

18 

19from .proc_ops import ( 

20 AddKnownDatasetStats, 

21 EnsureDtype, 

22 Processing, 

23 UpdateStats, 

24 get_proc, 

25) 

26from .sample import Sample 

27from .stat_calculators import StatsCalculator 

28from .stat_measures import ( 

29 DatasetMeasure, 

30 DatasetMeasureBase, 

31 Measure, 

32 MeasureValue, 

33 SampleMeasure, 

34 SampleMeasureBase, 

35) 

36 

# Any tensor description handled by this module: an input or an output tensor
# description from either the v0_4 or the v0_5 model description format.
TensorDescr = Union[
    v0_4.InputTensorDescr,
    v0_4.OutputTensorDescr,
    v0_5.InputTensorDescr,
    v0_5.OutputTensorDescr,
]

43 

44 

class PreAndPostprocessing(NamedTuple):
    """Pair of operator lists returned by `setup_pre_and_postprocessing`."""

    # operators built from the model's input tensor descriptions
    pre: List[Processing]
    # operators built from the model's output tensor descriptions
    post: List[Processing]

48 

49 

class _ProcessingCallables(NamedTuple):
    """Pair of callables returned by `get_pre_and_postprocessing`.

    Each callable takes a `Sample` and returns ``None``.
    """

    # applies the preprocessing operators to a sample
    pre: Callable[[Sample], None]
    # applies the postprocessing operators to a sample
    post: Callable[[Sample], None]

53 

54 

class _ApplyProcs:
    """Callable wrapper that runs a sequence of processing operators on a sample."""

    def __init__(self, procs: Sequence[Processing]):
        """Store `procs`; the sequence is kept by reference, not copied."""
        super().__init__()
        self._procs = procs

    def __call__(self, sample: Sample) -> None:
        """Call every stored operator with `sample`, in sequence order."""
        for apply_proc in self._procs:
            apply_proc(sample)

63 

64 

def get_pre_and_postprocessing(
    model: AnyModelDescr,
    *,
    dataset_for_initial_statistics: Iterable[Sample],
    keep_updating_initial_dataset_stats: bool = False,
    fixed_dataset_stats: Optional[Mapping[DatasetMeasure, MeasureValue]] = None,
) -> _ProcessingCallables:
    """Create callables that apply pre- and postprocessing in-place to a sample.

    All arguments are forwarded unchanged to `setup_pre_and_postprocessing`.
    Each of the two resulting operator lists is wrapped in an `_ApplyProcs`
    instance so it can be invoked as a single function on a `Sample`.

    Returns:
        A `_ProcessingCallables` tuple ``(pre, post)``.
    """
    pre_ops, post_ops = setup_pre_and_postprocessing(
        model=model,
        dataset_for_initial_statistics=dataset_for_initial_statistics,
        keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats,
        fixed_dataset_stats=fixed_dataset_stats,
    )
    apply_pre = _ApplyProcs(pre_ops)
    apply_post = _ApplyProcs(post_ops)
    return _ProcessingCallables(pre=apply_pre, post=apply_post)

81 

82 

def setup_pre_and_postprocessing(
    model: AnyModelDescr,
    dataset_for_initial_statistics: Iterable[Sample],
    keep_updating_initial_dataset_stats: bool = False,
    fixed_dataset_stats: Optional[Mapping[DatasetMeasure, MeasureValue]] = None,
) -> PreAndPostprocessing:
    """Get pre- and postprocessing operators for a `model` description.

    Used in `bioimageio.core.create_prediction_pipeline`.

    Args:
        model: Model description whose `inputs` and `outputs` describe the
            processing steps.
        dataset_for_initial_statistics: Samples fed to a `StatsCalculator`
            when at least one required measure is not covered by
            `fixed_dataset_stats`.
        keep_updating_initial_dataset_stats: Forwarded to `UpdateStats`.
        fixed_dataset_stats: Known measure values; if non-empty, an
            `AddKnownDatasetStats` operator is placed first in the
            preprocessing list.

    Returns:
        The preprocessing and postprocessing operator lists.
    """
    pre_ops = _get_described_procs(model.inputs)
    post_ops = _get_described_procs(model.outputs)

    # every measure requested by any described pre- or postprocessing operator
    required = {
        measure
        for op in chain(pre_ops, post_ops)
        for measure in op.required_measures
    }

    # measures for which no fixed value was supplied
    if fixed_dataset_stats is None:
        not_fixed = {measure for measure in required}
    else:
        not_fixed = {
            measure for measure in required if measure not in fixed_dataset_stats
        }

    if not_fixed:
        calculator = StatsCalculator(not_fixed)
        for sample in dataset_for_initial_statistics:
            calculator.update(sample)

        initial_stats = calculator.finalize()
        stats_updater = UpdateStats(
            StatsCalculator(required, initial_stats),
            keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats,
        )
        pre_ops.insert(0, stats_updater)

    if fixed_dataset_stats:
        # inserted last so that it ends up ahead of the stats updater
        pre_ops.insert(0, AddKnownDatasetStats(fixed_dataset_stats))

    return PreAndPostprocessing(pre=pre_ops, post=post_ops)

120 

121 

class RequiredMeasures(NamedTuple):
    """Measures required by a model's described processing operators."""

    # measures required by operators built from the model inputs
    pre: Set[Measure]
    # measures required by operators built from the model outputs
    post: Set[Measure]

125 

126 

class RequiredDatasetMeasures(NamedTuple):
    """The `DatasetMeasureBase` subset of a model's required measures."""

    # dataset measures required by operators built from the model inputs
    pre: Set[DatasetMeasure]
    # dataset measures required by operators built from the model outputs
    post: Set[DatasetMeasure]

130 

131 

class RequiredSampleMeasures(NamedTuple):
    """The `SampleMeasureBase` subset of a model's required measures."""

    # sample measures required by operators built from the model inputs
    pre: Set[SampleMeasure]
    # sample measures required by operators built from the model outputs
    post: Set[SampleMeasure]

135 

136 

def get_requried_measures(model: AnyModelDescr) -> RequiredMeasures:
    """Collect the measures required by the processing described in `model`.

    The misspelling ("requried") is part of this function's public name.

    Returns:
        A `RequiredMeasures` tuple holding the union of `required_measures`
        over the operators built from `model.inputs` (``pre``) and from
        `model.outputs` (``post``).
    """
    input_procs = _get_described_procs(model.inputs)
    output_procs = _get_described_procs(model.outputs)

    pre_measures = set(chain.from_iterable(p.required_measures for p in input_procs))
    post_measures = set(
        chain.from_iterable(p.required_measures for p in output_procs)
    )
    return RequiredMeasures(pre=pre_measures, post=post_measures)

144 

145 

def get_required_dataset_measures(model: AnyModelDescr) -> RequiredDatasetMeasures:
    """Return the required measures of `model` that are dataset measures.

    A measure is kept if it is an instance of `DatasetMeasureBase`.
    """
    pre_all, post_all = get_requried_measures(model)
    pre_dataset = {
        measure for measure in pre_all if isinstance(measure, DatasetMeasureBase)
    }
    post_dataset = {
        measure for measure in post_all if isinstance(measure, DatasetMeasureBase)
    }
    return RequiredDatasetMeasures(pre=pre_dataset, post=post_dataset)

152 

153 

def get_requried_sample_measures(model: AnyModelDescr) -> RequiredSampleMeasures:
    """Return the required measures of `model` that are sample measures.

    A measure is kept if it is an instance of `SampleMeasureBase`.
    The misspelling ("requried") is part of this function's public name.
    """
    pre_all, post_all = get_requried_measures(model)
    pre_sample = {
        measure for measure in pre_all if isinstance(measure, SampleMeasureBase)
    }
    post_sample = {
        measure for measure in post_all if isinstance(measure, SampleMeasureBase)
    }
    return RequiredSampleMeasures(pre=pre_sample, post=post_sample)

160 

161 

def _get_described_procs(
    tensor_descrs: Iterable[TensorDescr],
) -> List[Processing]:
    """Build the processing operators described by `tensor_descrs`.

    For every tensor description, one operator is created via `get_proc` for
    each entry of its `preprocessing` (input tensors) or `postprocessing`
    (output tensors).

    Tensors described in the v0_4 format additionally get `EnsureDtype`
    operators:

    * a trailing `EnsureDtype(dtype="float32")` is always appended, since
      float32 has been the implicit assumption for 0.4 models;
    * a leading `EnsureDtype(dtype=t_descr.data_type)` is added only if the
      tensor has at least one described processing step.

    Whether a tensor has described steps is decided per tensor, not from the
    length of the list accumulated over all tensors. Otherwise only the first
    tensor would ever have its redundant leading `EnsureDtype` omitted.

    Args:
        tensor_descrs: Input or output tensor descriptions of a model.

    Returns:
        All operators, in the order of `tensor_descrs`.
    """
    procs: List[Processing] = []
    for t_descr in tensor_descrs:
        # operators described for this tensor only
        described: List[Processing]
        if isinstance(t_descr, (v0_4.InputTensorDescr, v0_5.InputTensorDescr)):
            described = [get_proc(proc_d, t_descr) for proc_d in t_descr.preprocessing]
        elif isinstance(t_descr, (v0_4.OutputTensorDescr, v0_5.OutputTensorDescr)):
            described = [
                get_proc(proc_d, t_descr) for proc_d in t_descr.postprocessing
            ]
        else:
            assert_never(t_descr)

        if not isinstance(t_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
            procs.extend(described)
            continue

        member_id = get_member_id(t_descr)
        if described:
            # the initial ensure_dtype is only needed if other processing
            # steps follow it
            procs.append(
                EnsureDtype(input=member_id, output=member_id, dtype=t_descr.data_type)
            )
            procs.extend(described)

        # ensure 0.4 tensors end up as float32,
        # which has been the implicit assumption for 0.4
        procs.append(EnsureDtype(input=member_id, output=member_id, dtype="float32"))

    return procs