Coverage for bioimageio/core/prediction.py: 59%

63 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-19 09:02 +0000

1import collections.abc 

2from pathlib import Path 

3from typing import ( 

4 Any, 

5 Hashable, 

6 Iterable, 

7 Iterator, 

8 Mapping, 

9 Optional, 

10 Tuple, 

11 Union, 

12) 

13 

14import xarray as xr 

15from loguru import logger 

16from numpy.typing import NDArray 

17from tqdm import tqdm 

18 

19from bioimageio.spec import load_description 

20from bioimageio.spec.common import PermissiveFileSource 

21from bioimageio.spec.model import v0_4, v0_5 

22 

23from ._prediction_pipeline import PredictionPipeline, create_prediction_pipeline 

24from .axis import AxisId 

25from .common import MemberId, PerMember 

26from .digest_spec import create_sample_for_model 

27from .io import save_sample 

28from .sample import Sample 

29from .tensor import Tensor 

30 

31 

def predict(
    *,
    model: Union[
        PermissiveFileSource, v0_4.ModelDescr, v0_5.ModelDescr, PredictionPipeline
    ],
    inputs: Union[Sample, PerMember[Union[Tensor, xr.DataArray, NDArray[Any], Path]]],
    sample_id: Hashable = "sample",
    blocksize_parameter: Optional[
        Union[
            v0_5.ParameterizedSize_N,
            Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
        ]
    ] = None,
    input_block_shape: Optional[Mapping[MemberId, Mapping[AxisId, int]]] = None,
    skip_preprocessing: bool = False,
    skip_postprocessing: bool = False,
    save_output_path: Optional[Union[Path, str]] = None,
) -> Sample:
    """Run prediction for a single set of input(s) with a bioimage.io model.

    Args:
        model: model to predict with.
            May be given as RDF source, model description or prediction pipeline.
            A prediction pipeline is used as is; anything else is turned into one
            (loading the description first if it is not already one).
        inputs: the input sample or the named input(s) for this model as a dictionary.
            A dictionary is converted into a `Sample` for the model.
        sample_id: the sample id (only used when `inputs` is not already a `Sample`).
        blocksize_parameter: (optional) tile the input into blocks parametrized by
            blocksize according to any parametrized axis sizes defined in the model RDF.
            Note: For a predetermined, fixed block shape use `input_block_shape`.
        input_block_shape: (optional) tile the input sample tensors into blocks.
            Takes precedence over `blocksize_parameter` (which is then ignored with a
            warning).
            Note: For a parameterized block shape, not dealing with the exact block
            shape, use `blocksize_parameter`.
        skip_preprocessing: flag to skip the model's preprocessing
        skip_postprocessing: flag to skip the model's postprocessing
        save_output_path: A path with `{member_id}` (required) and optionally
            `{sample_id}` in it to save the output to. The placeholders are presumably
            filled in by `save_sample`.

    Returns:
        The output sample produced by the prediction pipeline.

    Raises:
        ValueError: if `save_output_path` lacks the `{member_id}` placeholder, or if
            `model` does not resolve to a model description.
    """
    if save_output_path is not None and "{member_id}" not in str(save_output_path):
        # Doubled braces render a literal `{member_id}` in the message; single
        # braces would evaluate an undefined name and raise NameError instead.
        raise ValueError(
            f"Missing `{{member_id}}` in save_output_path={save_output_path}"
        )

    if isinstance(model, PredictionPipeline):
        pp = model
    else:
        if not isinstance(model, (v0_4.ModelDescr, v0_5.ModelDescr)):
            loaded = load_description(model)
            if not isinstance(loaded, (v0_4.ModelDescr, v0_5.ModelDescr)):
                raise ValueError(f"expected model description, but got {loaded}")
            model = loaded

        pp = create_prediction_pipeline(model)

    if isinstance(inputs, Sample):
        sample = inputs
    else:
        sample = create_sample_for_model(
            pp.model_description, inputs=inputs, sample_id=sample_id
        )

    if input_block_shape is not None:
        # A fixed block shape wins over the parameterized one.
        if blocksize_parameter is not None:
            logger.warning(
                "ignoring blocksize_parameter={} in favor of input_block_shape={}",
                blocksize_parameter,
                input_block_shape,
            )

        output = pp.predict_sample_with_fixed_blocking(
            sample,
            input_block_shape=input_block_shape,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
        )
    elif blocksize_parameter is not None:
        output = pp.predict_sample_with_blocking(
            sample,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
            ns=blocksize_parameter,
        )
    else:
        output = pp.predict_sample_without_blocking(
            sample,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
        )

    if save_output_path is not None:
        save_sample(save_output_path, output)

    return output

123 

124 

def predict_many(
    *,
    model: Union[
        PermissiveFileSource, v0_4.ModelDescr, v0_5.ModelDescr, PredictionPipeline
    ],
    inputs: Iterable[PerMember[Union[Tensor, xr.DataArray, NDArray[Any], Path]]],
    sample_id: str = "sample{i:03}",
    blocksize_parameter: Optional[
        Union[
            v0_5.ParameterizedSize_N,
            Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
        ]
    ] = None,
    skip_preprocessing: bool = False,
    skip_postprocessing: bool = False,
    save_output_path: Optional[Union[Path, str]] = None,
    input_block_shape: Optional[Mapping[MemberId, Mapping[AxisId, int]]] = None,
) -> Iterator[Sample]:
    """Run prediction for multiple sets of inputs with a bioimage.io model.

    This is a generator: validation, model loading and prediction only happen
    once the returned iterator is consumed.

    Args:
        model: model to predict with.
            May be given as RDF source, model description or prediction pipeline.
            The prediction pipeline is created once and reused for all samples.
        inputs: An iterable of the named input(s) for this model as a dictionary.
            A single dictionary (mapping) is treated as exactly one sample.
        sample_id: the sample id.
            note: `{i}` will be formatted as the i-th sample.
            If `{i}` (or `{i:`) is not present and `inputs` is an iterable (not a
            single mapping), `{i:03}` is appended.
        blocksize_parameter: (optional) tile the input into blocks parametrized by
            blocksize according to any parametrized axis sizes defined in the model RDF
        skip_preprocessing: flag to skip the model's preprocessing
        skip_postprocessing: flag to skip the model's postprocessing
        save_output_path: A path with `{member_id}` `{sample_id}` in it
            to save the output to. `{sample_id}` may only be omitted when `inputs`
            is a single mapping.
        input_block_shape: (optional) fixed block shape forwarded to `predict`
            for every sample; takes precedence over `blocksize_parameter`.

    Yields:
        One output sample per input sample, in input order.

    Raises:
        ValueError: if `save_output_path` lacks a required placeholder, or if
            `model` does not resolve to a model description.
    """
    # A mapping is one sample's named inputs, not an iterable of samples
    # (iterating it would yield its member keys).
    single_sample = isinstance(inputs, collections.abc.Mapping)

    if save_output_path is not None:
        # Doubled braces render literal placeholders in the messages; single
        # braces would evaluate undefined names and raise NameError instead.
        if "{member_id}" not in str(save_output_path):
            raise ValueError(
                f"Missing `{{member_id}}` in save_output_path={save_output_path}"
            )

        # Without `{sample_id}`, every sample would be saved to the same path.
        if not single_sample and "{sample_id}" not in str(save_output_path):
            raise ValueError(
                f"Missing `{{sample_id}}` in save_output_path={save_output_path}"
            )

    if isinstance(model, PredictionPipeline):
        pp = model
    else:
        if not isinstance(model, (v0_4.ModelDescr, v0_5.ModelDescr)):
            loaded = load_description(model)
            if not isinstance(loaded, (v0_4.ModelDescr, v0_5.ModelDescr)):
                raise ValueError(f"expected model description, but got {loaded}")
            model = loaded

        pp = create_prediction_pipeline(model)

    sample_id = str(sample_id)
    samples: Iterable[Any]
    if single_sample:
        samples = [inputs]
    else:
        samples = inputs
        # Ensure distinct ids per sample when iterating many inputs.
        if "{i}" not in sample_id and "{i:" not in sample_id:
            sample_id += "{i:03}"

    total = len(samples) if isinstance(samples, collections.abc.Sized) else None

    for i, ipts in tqdm(enumerate(samples), total=total):
        yield predict(
            model=pp,
            inputs=ipts,
            sample_id=sample_id.format(i=i),
            blocksize_parameter=blocksize_parameter,
            input_block_shape=input_block_shape,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
            save_output_path=save_output_path,
        )