Coverage for bioimageio/core/prediction.py: 61%

57 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-01 09:51 +0000

1import collections.abc 

2from pathlib import Path 

3from typing import ( 

4 Hashable, 

5 Iterable, 

6 Iterator, 

7 Mapping, 

8 Optional, 

9 Tuple, 

10 Union, 

11) 

12 

13from loguru import logger 

14from tqdm import tqdm 

15 

16from bioimageio.spec import load_description 

17from bioimageio.spec.common import PermissiveFileSource 

18from bioimageio.spec.model import v0_4, v0_5 

19 

20from ._prediction_pipeline import PredictionPipeline, create_prediction_pipeline 

21from .axis import AxisId 

22from .common import BlocksizeParameter, MemberId, PerMember 

23from .digest_spec import TensorSource, create_sample_for_model, get_member_id 

24from .io import save_sample 

25from .sample import Sample 

26 

27 

def predict(
    *,
    model: Union[
        PermissiveFileSource, v0_4.ModelDescr, v0_5.ModelDescr, PredictionPipeline
    ],
    inputs: Union[Sample, PerMember[TensorSource], TensorSource],
    sample_id: Hashable = "sample",
    blocksize_parameter: Optional[BlocksizeParameter] = None,
    input_block_shape: Optional[Mapping[MemberId, Mapping[AxisId, int]]] = None,
    skip_preprocessing: bool = False,
    skip_postprocessing: bool = False,
    save_output_path: Optional[Union[Path, str]] = None,
) -> Sample:
    """Predict a single sample with a bioimage.io model.

    Args:
        model: What to predict with: an RDF source, a model description
            or an already created prediction pipeline.
            Unless a prediction pipeline is given, one is created here.
        inputs: Either a ready-made `Sample` or the model input(s)
            from which a sample is created.
        sample_id: Id given to the sample created from **inputs**.
            Not used by this function if **inputs** already is a `Sample`.
        blocksize_parameter: (optional) Predict blockwise with blocks
            parametrized by **blocksize_parameter** according to any
            parametrized axis sizes defined by the **model**.
            See `bioimageio.spec.model.v0_5.ParameterizedSize` for details.
            Note: For a predetermined, fixed block shape use **input_block_shape**.
        input_block_shape: (optional) Predict blockwise with this fixed block
            shape. Takes precedence over **blocksize_parameter**; a warning is
            logged if both are given.
        skip_preprocessing: Flag to skip the model's preprocessing.
        skip_postprocessing: Flag to skip the model's postprocessing.
        save_output_path: (optional) Path to save the output to.
            Must contain:
            - `{output_id}` (or `{member_id}`) if the model has multiple output tensors
            May contain:
            - `{sample_id}` to avoid overwriting recurrent calls

    Returns:
        The predicted output sample.

    Raises:
        ValueError: If **model** does not load as a model description, or if
            **save_output_path** lacks an `{output_id}`/`{member_id}` placeholder
            although the model has more than one output.
    """
    model_descr_types = (v0_4.ModelDescr, v0_5.ModelDescr)
    if isinstance(model, PredictionPipeline):
        pipeline = model
        model = pipeline.model_description
    else:
        if not isinstance(model, model_descr_types):
            descr = load_description(model)
            if not isinstance(descr, model_descr_types):
                raise ValueError(f"expected model description, but got {descr}")
            model = descr

        pipeline = create_prediction_pipeline(model)

    if save_output_path is not None:
        path_template = str(save_output_path)
        names_member = (
            "{output_id}" in path_template or "{member_id}" in path_template
        )
        # Without a member placeholder, several outputs would share one path.
        if not names_member and len(model.outputs) > 1:
            member_ids = [get_member_id(d) for d in model.outputs]
            raise ValueError(
                f"Missing `{{output_id}}` in save_output_path={save_output_path} to "
                + "distinguish model outputs "
                + str(member_ids)
            )

    sample = (
        inputs
        if isinstance(inputs, Sample)
        else create_sample_for_model(
            pipeline.model_description, inputs=inputs, sample_id=sample_id
        )
    )

    processing_flags = {
        "skip_preprocessing": skip_preprocessing,
        "skip_postprocessing": skip_postprocessing,
    }
    if input_block_shape is not None:
        # A fixed block shape wins over a parametrized one.
        if blocksize_parameter is not None:
            logger.warning(
                "ignoring blocksize_parameter={} in favor of input_block_shape={}",
                blocksize_parameter,
                input_block_shape,
            )

        output = pipeline.predict_sample_with_fixed_blocking(
            sample, input_block_shape=input_block_shape, **processing_flags
        )
    elif blocksize_parameter is not None:
        output = pipeline.predict_sample_with_blocking(
            sample, ns=blocksize_parameter, **processing_flags
        )
    else:
        output = pipeline.predict_sample_without_blocking(sample, **processing_flags)

    if save_output_path:
        save_sample(save_output_path, output)

    return output

128 

129 

def predict_many(
    *,
    model: Union[
        PermissiveFileSource, v0_4.ModelDescr, v0_5.ModelDescr, PredictionPipeline
    ],
    inputs: Union[Iterable[PerMember[TensorSource]], Iterable[TensorSource]],
    sample_id: str = "sample{i:03}",
    blocksize_parameter: Optional[
        Union[
            v0_5.ParameterizedSize_N,
            Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
        ]
    ] = None,
    input_block_shape: Optional[Mapping[MemberId, Mapping[AxisId, int]]] = None,
    skip_preprocessing: bool = False,
    skip_postprocessing: bool = False,
    save_output_path: Optional[Union[Path, str]] = None,
) -> Iterator[Sample]:
    """Run prediction for multiple sets of inputs with a bioimage.io model

    This is a generator: nothing (including argument validation and model
    loading) happens before the first sample is requested, and each sample is
    predicted only when the iterator is advanced.

    Args:
        model: Model to predict with.
            May be given as RDF source, model description or prediction pipeline.
            The prediction pipeline is created once and reused for all samples.
        inputs: An iterable of the input(s) for this model, one entry per sample.
        sample_id: The sample id template.
            note: `{i}` will be formatted as the index of the i-th sample.
                If neither `{i}` nor `{i:` is present (and **inputs** is not a
                mapping) `{i:03}` is appended.
        blocksize_parameter: (optional) Tile the input into blocks parametrized by
            blocksize according to any parametrized axis sizes defined in the model RDF.
            Note: For a predetermined, fixed block shape use **input_block_shape**.
        input_block_shape: (optional) Tile the input sample tensors into blocks of
            this fixed shape. Forwarded unchanged to `predict` for every sample,
            where it takes precedence over **blocksize_parameter**.
        skip_preprocessing: Flag to skip the model's preprocessing.
        skip_postprocessing: Flag to skip the model's postprocessing.
        save_output_path: A path to save the output to.
            Must contain:
            - `{sample_id}` to differentiate predicted samples
            - `{output_id}` (or `{member_id}`) if the model has multiple outputs

    Yields:
        The predicted output sample for each entry of **inputs**, in order.

    Raises:
        ValueError: If **save_output_path** lacks `{sample_id}`, or if **model**
            does not load as a model description (raised on first iteration).
    """
    if save_output_path is not None and "{sample_id}" not in str(save_output_path):
        raise ValueError(
            f"Missing `{{sample_id}}` in save_output_path={save_output_path}"
            + " to differentiate predicted samples."
        )

    if isinstance(model, PredictionPipeline):
        pp = model
    else:
        if not isinstance(model, (v0_4.ModelDescr, v0_5.ModelDescr)):
            loaded = load_description(model)
            if not isinstance(loaded, (v0_4.ModelDescr, v0_5.ModelDescr)):
                raise ValueError(f"expected model description, but got {loaded}")
            model = loaded

        pp = create_prediction_pipeline(model)

    if not isinstance(inputs, collections.abc.Mapping):
        # Make sure every sample gets its own id.
        if "{i}" not in sample_id and "{i:" not in sample_id:
            sample_id += "{i:03}"

    # Only sized inputs allow the progress bar to show a total.
    total = len(inputs) if isinstance(inputs, collections.abc.Sized) else None

    for i, ipts in tqdm(enumerate(inputs), total=total):
        yield predict(
            model=pp,
            inputs=ipts,
            sample_id=sample_id.format(i=i),
            blocksize_parameter=blocksize_parameter,
            input_block_shape=input_block_shape,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
            save_output_path=save_output_path,
        )