Coverage for src / bioimageio / core / prediction.py: 61%

57 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-18 12:35 +0000

1import collections.abc 

2from pathlib import Path 

3from typing import ( 

4 Hashable, 

5 Iterable, 

6 Iterator, 

7 Mapping, 

8 Optional, 

9 Tuple, 

10 Union, 

11) 

12 

13from loguru import logger 

14from tqdm import tqdm 

15 

16from bioimageio.spec import load_description 

17from bioimageio.spec.common import PermissiveFileSource 

18from bioimageio.spec.model import v0_4, v0_5 

19 

20from ._prediction_pipeline import PredictionPipeline, create_prediction_pipeline 

21from .axis import AxisId 

22from .common import BlocksizeParameter, MemberId, PerMember 

23from .digest_spec import TensorSource, create_sample_for_model, get_member_id 

24from .io import save_sample 

25from .sample import Sample 

26 

27 

def predict(
    *,
    model: Union[
        PermissiveFileSource, v0_4.ModelDescr, v0_5.ModelDescr, PredictionPipeline
    ],
    inputs: Union[Sample, PerMember[TensorSource], TensorSource],
    sample_id: Hashable = "sample",
    blocksize_parameter: Optional[BlocksizeParameter] = None,
    input_block_shape: Optional[Mapping[MemberId, Mapping[AxisId, int]]] = None,
    skip_preprocessing: bool = False,
    skip_postprocessing: bool = False,
    save_output_path: Optional[Union[Path, str]] = None,
) -> Sample:
    """Run prediction for a single set of input(s) with a bioimage.io model.

    Args:
        model: Model to predict with.
            May be given as RDF source, model description or prediction pipeline.
        inputs: The input sample, the named input(s) for this model as a
            dictionary, or a single tensor source.
        sample_id: The sample id.
            The **sample_id** is used to format **save_output_path**
            and to distinguish sample specific log messages.
            Ignored when **inputs** is already a `Sample`.
        blocksize_parameter: (optional) Tile the input into blocks parametrized by
            **blocksize_parameter** according to any parametrized axis sizes defined
            by the **model**.
            See `bioimageio.spec.model.v0_5.ParameterizedSize` for details.
            Note: For a predetermined, fixed block shape use **input_block_shape**.
        input_block_shape: (optional) Tile the input sample tensors into blocks.
            Takes precedence over **blocksize_parameter** if both are given
            (a warning is logged in that case).
            Note: Use **blocksize_parameter** for a parameterized block shape to
            run prediction independent of the exact block shape.
        skip_preprocessing: Flag to skip the model's preprocessing.
        skip_postprocessing: Flag to skip the model's postprocessing.
        save_output_path: A path to save the output to.
            Must contain:
            - `{output_id}` (or `{member_id}`) if the model has multiple output tensors
            May contain:
            - `{sample_id}` to avoid overwriting recurrent calls

    Returns:
        The predicted output sample.

    Raises:
        ValueError: If **model** does not resolve to a model description, or if
            **save_output_path** contains neither `{output_id}` nor `{member_id}`
            while the model has more than one output.
    """
    # Resolve `model` into a prediction pipeline `pp` plus its model description.
    if isinstance(model, PredictionPipeline):
        pp = model
        model = pp.model_description
    else:
        if not isinstance(model, (v0_4.ModelDescr, v0_5.ModelDescr)):
            loaded = load_description(model)
            if not isinstance(loaded, (v0_4.ModelDescr, v0_5.ModelDescr)):
                raise ValueError(f"expected model description, but got {loaded}")
            model = loaded

        pp = create_prediction_pipeline(
            model,
            # statistics already attached to a given Sample are reused as fixed
            fixed_dataset_statistics=inputs.stat if isinstance(inputs, Sample) else {},
        )

    # Validate the output path template before doing any (expensive) prediction.
    if save_output_path is not None:
        save_path_template = str(save_output_path)
        if (
            "{output_id}" not in save_path_template
            and "{member_id}" not in save_path_template
            and len(model.outputs) > 1
        ):
            raise ValueError(
                f"Missing `{{output_id}}` in save_output_path={save_output_path} to "
                + "distinguish model outputs "
                + str([get_member_id(d) for d in model.outputs])
            )

    if isinstance(inputs, Sample):
        sample = inputs
    else:
        sample = create_sample_for_model(
            pp.model_description, inputs=inputs, sample_id=sample_id
        )

    # A fixed block shape wins over a parameterized one.
    if input_block_shape is not None:
        if blocksize_parameter is not None:
            logger.warning(
                "ignoring blocksize_parameter={} in favor of input_block_shape={}",
                blocksize_parameter,
                input_block_shape,
            )

        output = pp.predict_sample_with_fixed_blocking(
            sample,
            input_block_shape=input_block_shape,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
        )
    elif blocksize_parameter is not None:
        output = pp.predict_sample_with_blocking(
            sample,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
            ns=blocksize_parameter,
        )
    else:
        output = pp.predict_sample_without_blocking(
            sample,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
        )

    if save_output_path:
        save_sample(save_output_path, output)

    return output

131 

132 

def predict_many(
    *,
    model: Union[
        PermissiveFileSource, v0_4.ModelDescr, v0_5.ModelDescr, PredictionPipeline
    ],
    inputs: Union[Iterable[PerMember[TensorSource]], Iterable[TensorSource]],
    sample_id: str = "sample{i:03}",
    blocksize_parameter: Optional[
        Union[
            v0_5.ParameterizedSize_N,
            Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
        ]
    ] = None,
    skip_preprocessing: bool = False,
    skip_postprocessing: bool = False,
    save_output_path: Optional[Union[Path, str]] = None,
    input_block_shape: Optional[Mapping[MemberId, Mapping[AxisId, int]]] = None,
) -> Iterator[Sample]:
    """Run prediction for multiple sets of inputs with a bioimage.io model.

    Args:
        model: Model to predict with.
            May be given as RDF source, model description or prediction pipeline.
        inputs: An iterable where each element holds the input(s) of one sample,
            either as a dictionary of named inputs or as a single tensor source.
        sample_id: The sample id.
            note: `{i}` will be formatted as the i-th sample.
            If `{i}` (or `{i:`) is not present and `inputs` is not a mapping,
            `{i:03}` is appended.
        blocksize_parameter: (optional) Tile the input into blocks parametrized by
            blocksize according to any parametrized axis sizes defined in the model RDF.
        skip_preprocessing: Flag to skip the model's preprocessing.
        skip_postprocessing: Flag to skip the model's postprocessing.
        save_output_path: A path to save the output to.
            Must contain:
            - `{sample_id}` to differentiate predicted samples
            - `{output_id}` (or `{member_id}`) if the model has multiple outputs
        input_block_shape: (optional) Tile the input sample tensors into blocks of
            this fixed shape (forwarded to `predict`). Takes precedence over
            **blocksize_parameter** if both are given.

    Yields:
        One predicted output sample per element of **inputs**, in order.

    Raises:
        ValueError: If **save_output_path** lacks `{sample_id}`, if **model** does
            not resolve to a model description, or (from `predict`) if a
            multi-output model's **save_output_path** lacks `{output_id}`.

    Note:
        This is a generator function: argument validation and model loading
        only run once iteration starts, not at call time.
    """
    if save_output_path is not None and "{sample_id}" not in str(save_output_path):
        raise ValueError(
            f"Missing `{{sample_id}}` in save_output_path={save_output_path}"
            + " to differentiate predicted samples."
        )

    # Build the prediction pipeline once and reuse it for every sample.
    if isinstance(model, PredictionPipeline):
        pp = model
    else:
        if not isinstance(model, (v0_4.ModelDescr, v0_5.ModelDescr)):
            loaded = load_description(model)
            if not isinstance(loaded, (v0_4.ModelDescr, v0_5.ModelDescr)):
                raise ValueError(f"expected model description, but got {loaded}")
            model = loaded

        pp = create_prediction_pipeline(model)

    # Resolve conflicting blocking options once, so `predict` does not repeat
    # the same warning for every sample.
    if input_block_shape is not None and blocksize_parameter is not None:
        logger.warning(
            "ignoring blocksize_parameter={} in favor of input_block_shape={}",
            blocksize_parameter,
            input_block_shape,
        )
        blocksize_parameter = None

    # NOTE(review): iterating a Mapping yields its keys, so for Mapping inputs
    # each key is passed to `predict` as one sample's inputs -- confirm intended.
    if not isinstance(inputs, collections.abc.Mapping):
        # ensure distinct sample ids (and thus distinct output paths) per sample
        if "{i}" not in sample_id and "{i:" not in sample_id:
            sample_id += "{i:03}"

    # progress bar length is only known for sized inputs
    total = len(inputs) if isinstance(inputs, collections.abc.Sized) else None

    for i, ipts in tqdm(enumerate(inputs), total=total):
        yield predict(
            model=pp,
            inputs=ipts,
            sample_id=sample_id.format(i=i),
            blocksize_parameter=blocksize_parameter,
            input_block_shape=input_block_shape,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
            save_output_path=save_output_path,
        )