Coverage for bioimageio/core/_prediction_pipeline.py: 89%

114 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-19 09:02 +0000

1import warnings 

2from types import MappingProxyType 

3from typing import ( 

4 Any, 

5 Iterable, 

6 List, 

7 Mapping, 

8 Optional, 

9 Sequence, 

10 Tuple, 

11 TypeVar, 

12 Union, 

13) 

14 

15from tqdm import tqdm 

16 

17from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 

18from bioimageio.spec.model.v0_5 import WeightsFormat 

19 

20from ._op_base import BlockedOperator 

21from .axis import AxisId, PerAxis 

22from .common import Halo, MemberId, PerMember, SampleId 

23from .digest_spec import ( 

24 get_block_transform, 

25 get_input_halo, 

26 get_member_ids, 

27) 

28from .model_adapters import ModelAdapter, create_model_adapter 

29from .model_adapters import get_weight_formats as get_weight_formats 

30from .proc_ops import Processing 

31from .proc_setup import setup_pre_and_postprocessing 

32from .sample import Sample, SampleBlock, SampleBlockWithOrigin 

33from .stat_measures import DatasetMeasure, MeasureValue, Stat 

34from .tensor import Tensor 

35 

# Constrained type variable: prediction helpers accept and return either a
# single Sample or an iterable of Samples.
Predict_IO = TypeVar("Predict_IO", Sample, Iterable[Sample])

41 

42 

class PredictionPipeline:
    """
    Represents model computation including preprocessing and postprocessing
    Note: Ideally use the PredictionPipeline as a context manager
    """

    def __init__(
        self,
        *,
        name: str,
        model_description: AnyModelDescr,
        preprocessing: List[Processing],
        postprocessing: List[Processing],
        model_adapter: ModelAdapter,
        default_ns: Union[
            v0_5.ParameterizedSize_N,
            Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
        ] = 10,
        default_batch_size: int = 1,
    ) -> None:
        """
        Args:
            name: name of this pipeline.
            model_description: bioimageio model description (v0.4 or v0.5).
            preprocessing: processing operations applied before inference.
            postprocessing: processing operations applied after inference.
            model_adapter: backend adapter that runs the actual model inference.
            default_ns: default parameter `n` used to resolve parameterized
                input sizes; either a single value for all parameterized axes
                or a mapping per (tensor id, axis id).
            default_batch_size: batch size used when none is given explicitly.
        """
        super().__init__()
        if model_description.run_mode:
            warnings.warn(
                f"Not yet implemented inference for run mode '{model_description.run_mode.name}'"
            )

        self.name = name
        self._preprocessing = preprocessing
        self._postprocessing = postprocessing

        self.model_description = model_description
        if isinstance(model_description, v0_4.ModelDescr):
            # v0.4 descriptions carry no halo/block-transform information
            self._default_input_halo: PerMember[PerAxis[Halo]] = {}
            self._block_transform = None
        else:
            # derive the default input halo from the halos declared on the
            # output axes of the (v0.5) model description
            default_output_halo = {
                t.id: {
                    a.id: Halo(a.halo, a.halo)
                    for a in t.axes
                    if isinstance(a, v0_5.WithHalo)
                }
                for t in model_description.outputs
            }
            self._default_input_halo = get_input_halo(
                model_description, default_output_halo
            )
            self._block_transform = get_block_transform(model_description)

        self._default_ns = default_ns
        self._default_batch_size = default_batch_size

        self._input_ids = get_member_ids(model_description.inputs)
        self._output_ids = get_member_ids(model_description.outputs)

        self._adapter: ModelAdapter = model_adapter

    def __enter__(self):
        self.load()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):  # type: ignore
        self.unload()
        return False  # do not suppress exceptions

    def predict_sample_block(
        self,
        sample_block: SampleBlockWithOrigin,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> SampleBlock:
        """predict a single sample block

        Raises:
            NotImplementedError: for v0.4 model descriptions (they provide no
                block transform).
        """
        if isinstance(self.model_description, v0_4.ModelDescr):
            raise NotImplementedError(
                f"predict_sample_block not implemented for model {self.model_description.format_version}"
            )
        else:
            assert self._block_transform is not None

        if not skip_preprocessing:
            self.apply_preprocessing(sample_block)

        output_meta = sample_block.get_transformed_meta(self._block_transform)
        # missing optional inputs are forwarded as None; outputs returned as
        # None are dropped from the result
        output = output_meta.with_data(
            {
                tid: out
                for tid, out in zip(
                    self._output_ids,
                    self._adapter.forward(
                        *(sample_block.members.get(t) for t in self._input_ids)
                    ),
                )
                if out is not None
            },
            stat=sample_block.stat,
        )
        if not skip_postprocessing:
            self.apply_postprocessing(output)

        return output

    def predict_sample_without_blocking(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> Sample:
        """predict a sample.
        The sample's tensor shapes have to match the model's input tensor description.
        If that is not the case, consider `predict_sample_with_blocking`"""

        if not skip_preprocessing:
            self.apply_preprocessing(sample)

        # missing optional inputs are forwarded as None; outputs returned as
        # None are dropped from the result
        output = Sample(
            members={
                out_id: out
                for out_id, out in zip(
                    self._output_ids,
                    self._adapter.forward(
                        *(sample.members.get(in_id) for in_id in self._input_ids)
                    ),
                )
                if out is not None
            },
            stat=sample.stat,
            id=sample.id,
        )
        if not skip_postprocessing:
            self.apply_postprocessing(output)

        return output

    def get_output_sample_id(self, input_sample_id: SampleId):
        """deprecated identity helper; the output sample id equals the input
        sample id."""
        warnings.warn(
            "`PredictionPipeline.get_output_sample_id()` is deprecated and will be"
            + " removed soon. Output sample id is equal to input sample id, hence this"
            + " function is not needed."
        )
        return input_sample_id

    def predict_sample_with_fixed_blocking(
        self,
        sample: Sample,
        input_block_shape: Mapping[MemberId, Mapping[AxisId, int]],
        *,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> Sample:
        """predict a sample by splitting it into blocks of the given fixed
        shape, predicting each block, and reassembling the output sample."""
        if not skip_preprocessing:
            self.apply_preprocessing(sample)

        n_blocks, input_blocks = sample.split_into_blocks(
            input_block_shape,
            halo=self._default_input_halo,
            pad_mode="reflect",
        )
        input_blocks = list(input_blocks)
        predicted_blocks: List[SampleBlock] = []
        for b in tqdm(
            input_blocks,
            desc=f"predict sample {sample.id or ''} with {self.model_description.id or self.model_description.name}",
            unit="block",
            unit_divisor=1,
            total=n_blocks,
        ):
            # pre/postprocessing is applied on the whole sample (above/below),
            # not per block
            predicted_blocks.append(
                self.predict_sample_block(
                    b, skip_preprocessing=True, skip_postprocessing=True
                )
            )

        predicted_sample = Sample.from_blocks(predicted_blocks)
        if not skip_postprocessing:
            self.apply_postprocessing(predicted_sample)

        return predicted_sample

    def predict_sample_with_blocking(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
        ns: Optional[
            Union[
                v0_5.ParameterizedSize_N,
                Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
            ]
        ] = None,
        batch_size: Optional[int] = None,
    ) -> Sample:
        """predict a sample by splitting it into blocks according to the model and the `ns` parameter"""

        if isinstance(self.model_description, v0_4.ModelDescr):
            raise NotImplementedError(
                "`predict_sample_with_blocking` not implemented for v0_4.ModelDescr"
                + f" {self.model_description.name}."
                + " Consider using `predict_sample_with_fixed_blocking`"
            )

        # explicit None check: `ns=0` and an empty mapping are legitimate
        # values and must not be replaced by the default (a plain `ns or ...`
        # would silently discard them)
        if ns is None:
            ns = self._default_ns
        if isinstance(ns, int):
            # expand a single n to all parameterized input axes
            ns = {
                (ipt.id, a.id): ns
                for ipt in self.model_description.inputs
                for a in ipt.axes
                if isinstance(a.size, v0_5.ParameterizedSize)
            }
        input_block_shape = self.model_description.get_tensor_sizes(
            ns, batch_size or self._default_batch_size
        ).inputs

        return self.predict_sample_with_fixed_blocking(
            sample,
            input_block_shape=input_block_shape,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
        )

    def apply_preprocessing(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
        """apply preprocessing in-place, also updates sample stats"""
        for op in self._preprocessing:
            op(sample)

    def apply_postprocessing(
        self, sample: Union[Sample, SampleBlock, SampleBlockWithOrigin]
    ) -> None:
        """apply postprocessing in-place, also updates samples stats"""
        for op in self._postprocessing:
            if isinstance(sample, (Sample, SampleBlockWithOrigin)):
                op(sample)
            elif not isinstance(op, BlockedOperator):
                # a plain SampleBlock (without origin) can only be processed
                # by operators that support block-wise application
                raise NotImplementedError(
                    "block wise update of output statistics not yet implemented"
                )
            else:
                op(sample)

    def load(self):
        """
        optional step: load model onto devices before calling forward if not using it as context manager
        """
        pass

    def unload(self):
        """
        free any device memory in use
        """
        self._adapter.unload()

316 

317 

def create_prediction_pipeline(
    bioimageio_model: AnyModelDescr,
    *,
    devices: Optional[Sequence[str]] = None,
    weight_format: Optional[WeightsFormat] = None,
    weights_format: Optional[WeightsFormat] = None,
    dataset_for_initial_statistics: Iterable[Union[Sample, Sequence[Tensor]]] = tuple(),
    keep_updating_initial_dataset_statistics: bool = False,
    fixed_dataset_statistics: Mapping[DatasetMeasure, MeasureValue] = MappingProxyType(
        {}
    ),
    model_adapter: Optional[ModelAdapter] = None,
    ns: Union[
        v0_5.ParameterizedSize_N,
        Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
    ] = 10,
    **deprecated_kwargs: Any,
) -> PredictionPipeline:
    """
    Creates prediction pipeline which includes:
    * computation of input statistics
    * preprocessing
    * model prediction
    * computation of output statistics
    * postprocessing
    """
    # `weight_format` is an alternative spelling; it takes precedence
    selected_format = weight_format or weights_format

    if deprecated_kwargs:
        warnings.warn(
            f"deprecated create_prediction_pipeline kwargs: {set(deprecated_kwargs)}"
        )

    if model_adapter is None:
        model_adapter = create_model_adapter(
            model_description=bioimageio_model,
            devices=devices,
            weight_format_priority_order=selected_format and (selected_format,),
        )

    member_ids = get_member_ids(bioimageio_model.inputs)

    def iterate_dataset():
        # one stat dict is shared by all samples built from raw tensor sequences
        shared_stat: Stat = {}
        for sample_id, item in enumerate(dataset_for_initial_statistics):
            if isinstance(item, Sample):
                yield item
            else:
                yield Sample(
                    members=dict(zip(member_ids, item)),
                    stat=shared_stat,
                    id=sample_id,
                )

    pre_ops, post_ops = setup_pre_and_postprocessing(
        bioimageio_model,
        iterate_dataset(),
        keep_updating_initial_dataset_stats=keep_updating_initial_dataset_statistics,
        fixed_dataset_stats=fixed_dataset_statistics,
    )

    return PredictionPipeline(
        name=bioimageio_model.name,
        model_description=bioimageio_model,
        model_adapter=model_adapter,
        preprocessing=pre_ops,
        postprocessing=post_ops,
        default_ns=ns,
    )

381 )