Coverage for src/bioimageio/core/prediction.py: 63%

59 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-06-22 16:54 +0000

1import collections.abc 

2from pathlib import Path 

3from typing import ( 

4 Hashable, 

5 Iterable, 

6 Iterator, 

7 Mapping, 

8 Optional, 

9 Tuple, 

10 Union, 

11) 

12 

13from loguru import logger 

14from tqdm import tqdm 

15 

16from bioimageio.spec import load_description 

17from bioimageio.spec.common import PermissiveFileSource 

18from bioimageio.spec.model import v0_4, v0_5 

19 

20from ._prediction_pipeline import PredictionPipeline, create_prediction_pipeline 

21from .axis import AxisId 

22from .common import BlocksizeParameter, MemberId, PerMember 

23from .digest_spec import TensorSource, create_sample_for_model, get_member_id 

24from .io import save_sample 

25from .sample import Sample 

26 

27 

def predict(
    *,
    model: Union[
        PermissiveFileSource, v0_4.ModelDescr, v0_5.ModelDescr, PredictionPipeline
    ],
    inputs: Union[Sample, PerMember[TensorSource], TensorSource],
    sample_id: Hashable = "sample",
    blocksize_parameter: Optional[BlocksizeParameter] = None,
    input_block_shape: Optional[Mapping[MemberId, Mapping[AxisId, int]]] = None,
    skip_preprocessing: bool = False,
    skip_postprocessing: bool = False,
    save_output_path: Optional[Union[Path, str]] = None,
) -> Sample:
    """Run prediction for a single set of input(s) with a bioimage.io model.

    Args:
        model: Model to predict with.
            May be given as RDF source, model description or prediction pipeline.
        inputs: The input sample, or the named input(s) for this model as a
            dictionary (a single tensor source is accepted as well).
        sample_id: The sample id.
            The **sample_id** is used to format **save_output_path**
            and to distinguish sample specific log messages.
            Note: Not used if **inputs** is already a `Sample`;
            that sample is used as is.
        blocksize_parameter: (optional) Tile the input into blocks parametrized by
            **blocksize_parameter** according to any parametrized axis sizes defined
            by the **model**.
            See `bioimageio.spec.model.v0_5.ParameterizedSize` for details.
            Note: For a predetermined, fixed block shape use **input_block_shape**.
        input_block_shape: (optional) Tile the input sample tensors into blocks.
            Takes precedence over **blocksize_parameter** (a warning is logged
            if both are given).
            Note: Use **blocksize_parameter** for a parameterized block shape to
            run prediction independent of the exact block shape.
        skip_preprocessing: Flag to skip the model's preprocessing.
        skip_postprocessing: Flag to skip the model's postprocessing.
        save_output_path: A path to save the output to (not saved if falsy).
            Must contain:
            - `{output_id}` (or `{member_id}`) if the model has multiple output tensors
            May contain:
            - `{sample_id}` to avoid overwriting recurrent calls

    Returns:
        The output sample computed by the prediction pipeline.

    Raises:
        ValueError: If **model** does not resolve to a model description, or if
            **save_output_path** lacks `{output_id}`/`{member_id}` while the
            model has multiple outputs.
    """
    if isinstance(model, PredictionPipeline):
        pp = model
    else:
        if not isinstance(model, (v0_4.ModelDescr, v0_5.ModelDescr)):
            loaded = load_description(model)
            if not isinstance(loaded, (v0_4.ModelDescr, v0_5.ModelDescr)):
                raise ValueError(f"expected model description, but got {loaded}")
            model = loaded

        pp = create_prediction_pipeline(
            model,
            # reuse precomputed statistics attached to a given sample
            fixed_dataset_statistics=inputs.stat if isinstance(inputs, Sample) else {},
        )

    with pp:
        model = pp.model_descr
        if save_output_path is not None:
            # multiple outputs would overwrite each other without a member placeholder
            if (
                "{output_id}" not in str(save_output_path)
                and "{member_id}" not in str(save_output_path)
                and len(model.outputs) > 1
            ):
                raise ValueError(
                    f"Missing `{{output_id}}` in save_output_path={save_output_path} to "
                    + "distinguish model outputs "
                    + str([get_member_id(d) for d in model.outputs])
                )

        if isinstance(inputs, Sample):
            sample = inputs
        else:
            sample = create_sample_for_model(
                pp.model_descr, inputs=inputs, sample_id=sample_id
            )

        if input_block_shape is not None:
            if blocksize_parameter is not None:
                logger.warning(
                    "ignoring blocksize_parameter={} in favor of input_block_shape={}",
                    blocksize_parameter,
                    input_block_shape,
                )

            output = pp.predict_sample_with_fixed_blocking(
                sample,
                input_block_shape=input_block_shape,
                skip_preprocessing=skip_preprocessing,
                skip_postprocessing=skip_postprocessing,
            )
        elif blocksize_parameter is not None:
            output = pp.predict_sample_with_blocking(
                sample,
                skip_preprocessing=skip_preprocessing,
                skip_postprocessing=skip_postprocessing,
                ns=blocksize_parameter,
            )
        else:
            output = pp.predict_sample_without_blocking(
                sample,
                skip_preprocessing=skip_preprocessing,
                skip_postprocessing=skip_postprocessing,
            )
        if save_output_path:
            save_sample(save_output_path, output)

    return output

133 

134 

def predict_many(
    *,
    model: Union[
        PermissiveFileSource, v0_4.ModelDescr, v0_5.ModelDescr, PredictionPipeline
    ],
    inputs: Union[Iterable[PerMember[TensorSource]], Iterable[TensorSource]],
    sample_id: str = "sample{i:03}",
    blocksize_parameter: Optional[
        Union[
            v0_5.ParameterizedSize_N,
            Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
        ]
    ] = None,
    skip_preprocessing: bool = False,
    skip_postprocessing: bool = False,
    save_output_path: Optional[Union[Path, str]] = None,
    input_block_shape: Optional[Mapping[MemberId, Mapping[AxisId, int]]] = None,
) -> Iterator[Sample]:
    """Run prediction for multiple sets of inputs with a bioimage.io model.

    Argument checks run when this function is called. The model is resolved
    and predictions are computed lazily while the returned iterator is consumed.

    Args:
        model: Model to predict with.
            May be given as RDF source, model description or prediction pipeline.
        inputs: An iterable of the named input(s) for this model as a dictionary.
        sample_id: The sample id.
            note: `{i}` will be formatted as the i-th sample.
            If `{i}` (or `{i:`) is not present and `inputs` is not a mapping,
            `{i:03}` is appended.
        blocksize_parameter: (optional) Tile the input into blocks parametrized by
            blocksize according to any parametrized axis sizes defined in the model RDF.
        skip_preprocessing: Flag to skip the model's preprocessing.
        skip_postprocessing: Flag to skip the model's postprocessing.
        save_output_path: A path to save the output to.
            Must contain:
            - `{sample_id}` to differentiate predicted samples
            - `{output_id}` (or `{member_id}`) if the model has multiple outputs
        input_block_shape: (optional) Tile the input sample tensors into blocks of
            this fixed shape. Takes precedence over **blocksize_parameter**.

    Returns:
        An iterator yielding one output `Sample` per element of **inputs**.

    Raises:
        ValueError: At call time, if **save_output_path** lacks `{sample_id}`.
            During iteration, if **model** does not resolve to a model
            description, or if `predict` rejects **save_output_path**.
    """
    if save_output_path is not None and "{sample_id}" not in str(save_output_path):
        raise ValueError(
            f"Missing `{{sample_id}}` in save_output_path={save_output_path}"
            + " to differentiate predicted samples."
        )

    if input_block_shape is not None and blocksize_parameter is not None:
        # Warn once here rather than once per sample inside `predict`.
        logger.warning(
            "ignoring blocksize_parameter={} in favor of input_block_shape={}",
            blocksize_parameter,
            input_block_shape,
        )
        blocksize_parameter = None

    if not isinstance(inputs, collections.abc.Mapping):
        if "{i}" not in sample_id and "{i:" not in sample_id:
            # ensure each yielded sample gets a distinct id
            sample_id += "{i:03}"

    # Delegate the iteration to a separate generator: if this function contained
    # `yield` itself, the checks above would only run on the first `next()` call.
    return _predict_many_lazily(
        model=model,
        inputs=inputs,
        sample_id=sample_id,
        blocksize_parameter=blocksize_parameter,
        input_block_shape=input_block_shape,
        skip_preprocessing=skip_preprocessing,
        skip_postprocessing=skip_postprocessing,
        save_output_path=save_output_path,
    )


def _predict_many_lazily(
    *,
    model: Union[
        PermissiveFileSource, v0_4.ModelDescr, v0_5.ModelDescr, PredictionPipeline
    ],
    inputs: Union[Iterable[PerMember[TensorSource]], Iterable[TensorSource]],
    sample_id: str,
    blocksize_parameter: Optional[
        Union[
            v0_5.ParameterizedSize_N,
            Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
        ]
    ],
    input_block_shape: Optional[Mapping[MemberId, Mapping[AxisId, int]]],
    skip_preprocessing: bool,
    skip_postprocessing: bool,
    save_output_path: Optional[Union[Path, str]],
) -> Iterator[Sample]:
    """Resolve **model** and yield `predict` results for each element of **inputs**.

    Private helper of `predict_many`; expects already validated arguments.
    """
    if isinstance(model, PredictionPipeline):
        pp = model
    else:
        if not isinstance(model, (v0_4.ModelDescr, v0_5.ModelDescr)):
            loaded = load_description(model)
            if not isinstance(loaded, (v0_4.ModelDescr, v0_5.ModelDescr)):
                raise ValueError(f"expected model description, but got {loaded}")
            model = loaded

        pp = create_prediction_pipeline(model)

    # tqdm can only show a total for sized inputs
    total = len(inputs) if isinstance(inputs, collections.abc.Sized) else None

    for i, ipts in tqdm(enumerate(inputs), total=total):
        yield predict(
            model=pp,
            inputs=ipts,
            sample_id=sample_id.format(i=i),
            blocksize_parameter=blocksize_parameter,
            input_block_shape=input_block_shape,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
            save_output_path=save_output_path,
        )