Coverage for bioimageio/core/_prediction_pipeline.py: 89%

122 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-01 09:51 +0000

1import warnings 

2from types import MappingProxyType 

3from typing import ( 

4 Any, 

5 Iterable, 

6 List, 

7 Mapping, 

8 Optional, 

9 Sequence, 

10 Tuple, 

11 TypeVar, 

12 Union, 

13) 

14 

15from loguru import logger 

16from tqdm import tqdm 

17 

18from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 

19 

20from ._op_base import BlockedOperator 

21from .axis import AxisId, PerAxis 

22from .common import ( 

23 BlocksizeParameter, 

24 Halo, 

25 MemberId, 

26 PerMember, 

27 SampleId, 

28 SupportedWeightsFormat, 

29) 

30from .digest_spec import ( 

31 get_block_transform, 

32 get_input_halo, 

33 get_member_ids, 

34) 

35from .model_adapters import ModelAdapter, create_model_adapter 

36from .model_adapters import get_weight_formats as get_weight_formats 

37from .proc_ops import Processing 

38from .proc_setup import setup_pre_and_postprocessing 

39from .sample import Sample, SampleBlock, SampleBlockWithOrigin 

40from .stat_measures import DatasetMeasure, MeasureValue, Stat 

41from .tensor import Tensor 

42 

# Constrained TypeVar: a prediction input/output is either a single `Sample`
# or an iterable of `Sample`s (the constraint list restricts substitution to
# exactly these two forms).
Predict_IO = TypeVar(
    "Predict_IO",
    Sample,
    Iterable[Sample],
)

48 

49 

class PredictionPipeline:
    """
    Represents model computation including preprocessing and postprocessing
    Note: Ideally use the `PredictionPipeline` in a with statement
    (as a context manager).
    """

    def __init__(
        self,
        *,
        name: str,
        model_description: AnyModelDescr,
        preprocessing: List[Processing],
        postprocessing: List[Processing],
        model_adapter: ModelAdapter,
        default_ns: Optional[BlocksizeParameter] = None,
        default_blocksize_parameter: BlocksizeParameter = 10,
        default_batch_size: int = 1,
    ) -> None:
        """Use `create_prediction_pipeline` to create a `PredictionPipeline`"""
        super().__init__()
        # Honor the deprecated `default_ns` whenever it was explicitly given,
        # even if it is falsy (e.g. 0), instead of silently ignoring it.
        if default_ns is not None:
            warnings.warn(
                "Argument `default_ns` is deprecated in favor of"
                + " `default_blocksize_parameter` and will be removed soon."
            )
            default_blocksize_parameter = default_ns
        del default_ns

        if model_description.run_mode:
            # run modes are not supported yet; proceed with plain inference
            warnings.warn(
                f"Not yet implemented inference for run mode '{model_description.run_mode.name}'"
            )

        self.name = name
        self._preprocessing = preprocessing
        self._postprocessing = postprocessing

        self.model_description = model_description
        if isinstance(model_description, v0_4.ModelDescr):
            # v0_4 models carry no halo/block metadata; blockwise prediction
            # via `predict_sample_block` is unavailable for them.
            self._default_input_halo: PerMember[PerAxis[Halo]] = {}
            self._block_transform = None
        else:
            # derive the default input halo from the halos declared on the
            # output axes (symmetric halo of size `a.halo` on each side)
            default_output_halo = {
                t.id: {
                    a.id: Halo(a.halo, a.halo)
                    for a in t.axes
                    if isinstance(a, v0_5.WithHalo)
                }
                for t in model_description.outputs
            }
            self._default_input_halo = get_input_halo(
                model_description, default_output_halo
            )
            self._block_transform = get_block_transform(model_description)

        self._default_blocksize_parameter = default_blocksize_parameter
        self._default_batch_size = default_batch_size

        self._input_ids = get_member_ids(model_description.inputs)
        self._output_ids = get_member_ids(model_description.outputs)

        self._adapter: ModelAdapter = model_adapter

    def __enter__(self):
        self.load()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):  # type: ignore
        self.unload()
        return False  # do not suppress exceptions

    def predict_sample_block(
        self,
        sample_block: SampleBlockWithOrigin,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> SampleBlock:
        """Predict a single sample block.

        Not supported for v0_4 models (they provide no block transform).
        """
        if isinstance(self.model_description, v0_4.ModelDescr):
            raise NotImplementedError(
                f"predict_sample_block not implemented for model {self.model_description.format_version}"
            )
        else:
            assert self._block_transform is not None

        if not skip_preprocessing:
            self.apply_preprocessing(sample_block)

        # the output block's metadata is the transformed input block metadata
        output_meta = sample_block.get_transformed_meta(self._block_transform)
        local_output = self._adapter.forward(sample_block)

        output = output_meta.with_data(local_output.members, stat=local_output.stat)
        if not skip_postprocessing:
            self.apply_postprocessing(output)

        return output

    def predict_sample_without_blocking(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> Sample:
        """predict a sample.
        The sample's tensor shapes have to match the model's input tensor description.
        If that is not the case, consider `predict_sample_with_blocking`"""

        if not skip_preprocessing:
            self.apply_preprocessing(sample)

        output = self._adapter.forward(sample)
        if not skip_postprocessing:
            self.apply_postprocessing(output)

        return output

    def get_output_sample_id(self, input_sample_id: SampleId):
        """Deprecated identity mapping from input to output sample id."""
        warnings.warn(
            "`PredictionPipeline.get_output_sample_id()` is deprecated and will be"
            + " removed soon. Output sample id is equal to input sample id, hence this"
            + " function is not needed."
        )
        return input_sample_id

    def predict_sample_with_fixed_blocking(
        self,
        sample: Sample,
        input_block_shape: Mapping[MemberId, Mapping[AxisId, int]],
        *,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> Sample:
        """Predict `sample` by splitting it into blocks of the given
        `input_block_shape`, predicting each block, and reassembling the
        outputs into a single sample."""
        if not skip_preprocessing:
            self.apply_preprocessing(sample)

        n_blocks, input_blocks = sample.split_into_blocks(
            input_block_shape,
            halo=self._default_input_halo,
            pad_mode="reflect",
        )
        input_blocks = list(input_blocks)
        predicted_blocks: List[SampleBlock] = []
        logger.info(
            "split sample shape {} into {} blocks of {}.",
            {k: dict(v) for k, v in sample.shape.items()},
            n_blocks,
            {k: dict(v) for k, v in input_block_shape.items()},
        )
        for b in tqdm(
            input_blocks,
            desc=f"predict {sample.id or ''} with {self.model_description.id or self.model_description.name}",
            unit="block",
            unit_divisor=1,
            total=n_blocks,
        ):
            # pre/postprocessing is applied on the whole sample, not per block
            predicted_blocks.append(
                self.predict_sample_block(
                    b, skip_preprocessing=True, skip_postprocessing=True
                )
            )

        predicted_sample = Sample.from_blocks(predicted_blocks)
        if not skip_postprocessing:
            self.apply_postprocessing(predicted_sample)

        return predicted_sample

    def predict_sample_with_blocking(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
        ns: Optional[
            Union[
                v0_5.ParameterizedSize_N,
                Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
            ]
        ] = None,
        batch_size: Optional[int] = None,
    ) -> Sample:
        """predict a sample by splitting it into blocks according to the model and the `ns` parameter"""

        if isinstance(self.model_description, v0_4.ModelDescr):
            raise NotImplementedError(
                "`predict_sample_with_blocking` not implemented for v0_4.ModelDescr"
                + f" {self.model_description.name}."
                + " Consider using `predict_sample_with_fixed_blocking`"
            )

        # use the pipeline default only when `ns` was not given explicitly
        # (an explicit 0 or empty mapping must not be silently replaced)
        if ns is None:
            ns = self._default_blocksize_parameter
        if isinstance(ns, int):
            # expand a scalar n to all parameterized input axes
            ns = {
                (ipt.id, a.id): ns
                for ipt in self.model_description.inputs
                for a in ipt.axes
                if isinstance(a.size, v0_5.ParameterizedSize)
            }
        input_block_shape = self.model_description.get_tensor_sizes(
            ns, batch_size or self._default_batch_size
        ).inputs

        return self.predict_sample_with_fixed_blocking(
            sample,
            input_block_shape=input_block_shape,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
        )

    def apply_preprocessing(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
        """apply preprocessing in-place, also updates sample stats"""
        for op in self._preprocessing:
            op(sample)

    def apply_postprocessing(
        self, sample: Union[Sample, SampleBlock, SampleBlockWithOrigin]
    ) -> None:
        """apply postprocessing in-place, also updates samples stats"""
        for op in self._postprocessing:
            if isinstance(sample, (Sample, SampleBlockWithOrigin)):
                op(sample)
            elif not isinstance(op, BlockedOperator):
                # plain SampleBlocks lack origin stats; only blockwise
                # operators can be applied to them
                raise NotImplementedError(
                    "block wise update of output statistics not yet implemented"
                )
            else:
                op(sample)

    def load(self):
        """
        optional step: load model onto devices before calling forward if not using it as context manager
        """
        pass

    def unload(self):
        """
        free any device memory in use
        """
        self._adapter.unload()

314 

315 

def create_prediction_pipeline(
    bioimageio_model: AnyModelDescr,
    *,
    devices: Optional[Sequence[str]] = None,
    weight_format: Optional[SupportedWeightsFormat] = None,
    weights_format: Optional[SupportedWeightsFormat] = None,
    dataset_for_initial_statistics: Iterable[Union[Sample, Sequence[Tensor]]] = tuple(),
    keep_updating_initial_dataset_statistics: bool = False,
    fixed_dataset_statistics: Mapping[DatasetMeasure, MeasureValue] = MappingProxyType(
        {}
    ),
    model_adapter: Optional[ModelAdapter] = None,
    ns: Optional[BlocksizeParameter] = None,
    default_blocksize_parameter: BlocksizeParameter = 10,
    **deprecated_kwargs: Any,
) -> PredictionPipeline:
    """
    Creates prediction pipeline which includes:
    * computation of input statistics
    * preprocessing
    * model prediction
    * computation of output statistics
    * postprocessing

    Args:
        bioimageio_model: A bioimageio model description.
        devices: (optional)
        weight_format: deprecated in favor of **weights_format**
        weights_format: (optional) Use a specific **weights_format** rather than
            choosing one automatically.
            A corresponding `bioimageio.core.model_adapters.ModelAdapter` will be
            created to run inference with the **bioimageio_model**.
        dataset_for_initial_statistics: (optional) If preprocessing steps require input
            dataset statistics, **dataset_for_initial_statistics** allows you to
            specify a dataset from which these statistics are computed.
        keep_updating_initial_dataset_statistics: (optional) Set to `True` if you want
            to update dataset statistics with each processed sample.
        fixed_dataset_statistics: (optional) Allows you to specify a mapping of
            `DatasetMeasure`s to precomputed `MeasureValue`s.
        model_adapter: (optional) Allows you to use a custom **model_adapter** instead
            of creating one according to the present/selected **weights_format**.
        ns: deprecated in favor of **default_blocksize_parameter**
        default_blocksize_parameter: Allows to control the default block size for
            blockwise predictions, see `BlocksizeParameter`.

    """
    # deprecated aliases take precedence when explicitly provided (use
    # `is not None` so explicit falsy values are not silently dropped)
    if weight_format is not None:
        weights_format = weight_format
    del weight_format
    if ns is not None:
        default_blocksize_parameter = ns
    del ns
    if deprecated_kwargs:
        warnings.warn(
            f"deprecated create_prediction_pipeline kwargs: {set(deprecated_kwargs)}"
        )

    model_adapter = model_adapter or create_model_adapter(
        model_description=bioimageio_model,
        devices=devices,
        # a single selected format becomes a one-element priority order;
        # `None` lets the adapter choose automatically
        weight_format_priority_order=weights_format and (weights_format,),
    )

    input_ids = get_member_ids(bioimageio_model.inputs)

    def dataset():
        # wrap raw tensor sequences as Samples sharing one common stat dict
        common_stat: Stat = {}
        for i, x in enumerate(dataset_for_initial_statistics):
            if isinstance(x, Sample):
                yield x
            else:
                yield Sample(members=dict(zip(input_ids, x)), stat=common_stat, id=i)

    preprocessing, postprocessing = setup_pre_and_postprocessing(
        bioimageio_model,
        dataset(),
        keep_updating_initial_dataset_stats=keep_updating_initial_dataset_statistics,
        fixed_dataset_stats=fixed_dataset_statistics,
    )

    return PredictionPipeline(
        name=bioimageio_model.name,
        model_description=bioimageio_model,
        model_adapter=model_adapter,
        preprocessing=preprocessing,
        postprocessing=postprocessing,
        default_blocksize_parameter=default_blocksize_parameter,
    )