Coverage for src / bioimageio / core / _prediction_pipeline.py: 84%

171 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-18 12:35 +0000

1import warnings 

2from types import MappingProxyType 

3from typing import ( 

4 Any, 

5 Iterable, 

6 List, 

7 Literal, 

8 Mapping, 

9 Optional, 

10 Sequence, 

11 Tuple, 

12 TypeVar, 

13 Union, 

14) 

15 

16from loguru import logger 

17from tqdm import tqdm 

18 

19from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 

20 

21from ._op_base import BlockwiseOperator, SamplewiseOperator 

22from .axis import AxisId, PerAxis 

23from .common import ( 

24 BlocksizeParameter, 

25 Halo, 

26 MemberId, 

27 PerMember, 

28 SampleId, 

29 SupportedWeightsFormat, 

30) 

31from .digest_spec import ( 

32 get_block_transform, 

33 get_input_halo, 

34 get_member_ids, 

35) 

36from .model_adapters import ModelAdapter, create_model_adapter 

37from .model_adapters import get_weight_formats as get_weight_formats 

38from .proc_ops import Processing 

39from .proc_setup import setup_pre_and_postprocessing 

40from .sample import Sample, SampleBlock 

41from .stat_measures import Measure, MeasureValue, Stat 

42from .tensor import Tensor 

43 

# Constrained type variable: resolves to either a single `Sample` or an
# iterable of `Sample`s.
Predict_IO = TypeVar("Predict_IO", Sample, Iterable[Sample])

49 

50 

class PredictionPipeline:
    """
    Represents model computation including preprocessing and postprocessing.

    Note: Ideally use the `PredictionPipeline` in a with statement
        (as a context manager).
    """

    def __init__(
        self,
        *,
        name: str,
        model_description: AnyModelDescr,
        preprocessing: List[Processing],
        postprocessing: List[Processing],
        model_adapter: ModelAdapter,
        default_ns: Optional[BlocksizeParameter] = None,
        default_blocksize_parameter: BlocksizeParameter = 10,
        default_batch_size: int = 1,
    ) -> None:
        """Consider using `create_prediction_pipeline` to create a `PredictionPipeline` with sensible defaults.

        Args:
            name: name of this pipeline.
            model_description: description of the model to run.
            preprocessing: operators applied (in order) before the model.
            postprocessing: operators applied (in order) after the model.
            model_adapter: adapter used for the forward pass.
            default_ns: deprecated in favor of **default_blocksize_parameter**.
            default_blocksize_parameter: block size parameter used by
                `predict_sample_with_blocking` if none is given there.
            default_batch_size: batch size used by
                `predict_sample_with_blocking` if none is given there.
        """
        super().__init__()
        if default_ns is not None:
            warnings.warn(
                "Argument `default_ns` is deprecated in favor of"
                + " `default_blocksize_parameter` and will be removed soon.",
                stacklevel=2,
            )
            # An explicit integer (including `0`) is honored; only an empty
            # mapping still falls back to `default_blocksize_parameter`.
            if isinstance(default_ns, int) or default_ns:
                default_blocksize_parameter = default_ns
        del default_ns

        if model_description.run_mode:
            warnings.warn(
                f"Not yet implemented inference for run mode '{model_description.run_mode.name}'"
            )

        self.name = name
        # Split preprocessing into a samplewise and a blockwise part.
        # Samplewise preprocessing is everything up to and including the last
        # samplewise operator; blockwise preprocessing are the remaining
        # (trailing) blockwise operators.
        # I.e. the samplewise part may contain blockwise operators
        # (if they are followed by a samplewise operator at some point).
        self._samplewise_preprocessing: List[
            Union[SamplewiseOperator, BlockwiseOperator]
        ] = []
        self._blockwise_preprocessing: List[BlockwiseOperator] = []
        for op in preprocessing[::-1]:
            if isinstance(op, BlockwiseOperator) and not self._samplewise_preprocessing:
                self._blockwise_preprocessing.insert(0, op)
            else:
                self._samplewise_preprocessing.insert(0, op)

        # Split postprocessing analogously, but here we start blockwise and
        # switch to samplewise at the first samplewise operator.
        self._blockwise_postprocessing: List[BlockwiseOperator] = []
        self._samplewise_postprocessing: List[
            Union[BlockwiseOperator, SamplewiseOperator]
        ] = []
        for op in postprocessing:
            if (
                isinstance(op, BlockwiseOperator)
                and not self._samplewise_postprocessing
            ):
                self._blockwise_postprocessing.append(op)
            else:
                self._samplewise_postprocessing.append(op)

        # v0_4 descriptions get no per-input pad mode; for other descriptions
        # inputs without an explicit `pad` use symmetric padding.
        self.pad_mode = (
            {}
            if isinstance(model_description, v0_4.ModelDescr)
            else {
                descr.id: descr.pad or v0_5.SymmetricPadding()
                for descr in model_description.inputs
            }
        )
        self.model_description = model_description
        if isinstance(model_description, v0_4.ModelDescr):
            self._default_output_halo: PerMember[PerAxis[Halo]] = {}
            self._default_input_halo: PerMember[PerAxis[Halo]] = {}
            self._block_transform = None
        else:
            # the same halo is used on both sides of each output axis
            self._default_output_halo = {
                t.id: {
                    a.id: Halo(a.halo, a.halo)
                    for a in t.axes
                    if isinstance(a, v0_5.WithHalo)
                }
                for t in model_description.outputs
            }
            self._default_input_halo = get_input_halo(
                model_description, self._default_output_halo
            )
            self._block_transform = get_block_transform(model_description)

        self._default_blocksize_parameter = default_blocksize_parameter
        self._default_batch_size = default_batch_size

        self._input_ids = get_member_ids(model_description.inputs)
        self._output_ids = get_member_ids(model_description.outputs)

        self._adapter: ModelAdapter = model_adapter

    def __enter__(self):
        self.load()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):  # type: ignore
        self.unload()
        # never suppress an exception raised inside the with block
        return False

    @property
    def has_blockwise_preprocessing(self) -> bool:
        """`True` if the preprocessing ends with at least one blockwise operator,
        i.e. if part of the preprocessing is applied per block."""
        return bool(self._blockwise_preprocessing)

    @property
    def has_blockwise_postprocessing(self) -> bool:
        """`True` if the postprocessing starts with at least one blockwise operator,
        i.e. if part of the postprocessing is applied per block."""
        return bool(self._blockwise_postprocessing)

    def _raise_for_non_blockwise_processing(
        self, proc_type: Literal["preprocessing", "postprocessing"]
    ):
        """Raise if the samplewise part of **proc_type** holds non-blockwise operators.

        Raises:
            NotImplementedError: naming the classes of all non-blockwise operators.
        """
        ops = (
            self._samplewise_preprocessing
            if proc_type == "preprocessing"
            else self._samplewise_postprocessing
        )
        non_blockwise = [
            op.__class__.__name__ for op in ops if not isinstance(op, BlockwiseOperator)
        ]
        if non_blockwise:
            raise NotImplementedError(
                f"Blockwise {proc_type} for {non_blockwise} not implemented."
            )

    def raise_for_non_blockwise_preprocessing(self):
        """
        Raises:
            NotImplementedError: if there are any non-blockwise preprocessing operators in the pipeline
        """
        self._raise_for_non_blockwise_processing("preprocessing")

    def raise_for_non_blockwise_postprocessing(self):
        """
        Raises:
            NotImplementedError: if there are any non-blockwise postprocessing operators in the pipeline
        """
        self._raise_for_non_blockwise_processing("postprocessing")

    def predict_sample_block(
        self,
        sample_block: SampleBlock,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> SampleBlock:
        """Predict a single sample block.

        Args:
            sample_block: input block
            skip_preprocessing: if `True`, skip all preprocessing steps.
            skip_postprocessing: if `True`, skip all postprocessing steps.

        Raises:
            NotImplementedError: for v0_4 model descriptions, or if a
                processing step that is not skipped is not blockwise.
        """
        if isinstance(self.model_description, v0_4.ModelDescr):
            raise NotImplementedError(
                f"predict_sample_block not implemented for model {self.model_description.format_version}"
            )
        else:
            assert self._block_transform is not None

        # check both processing chains up front, i.e. before any work is done
        if not skip_preprocessing:
            self.raise_for_non_blockwise_preprocessing()

        if not skip_postprocessing:
            self.raise_for_non_blockwise_postprocessing()

        if not skip_preprocessing:
            self.apply_preprocessing(sample_block)

        output_meta = sample_block.get_transformed_meta(self._block_transform)
        local_output = self._adapter.forward(sample_block)

        output = output_meta.with_data(local_output.members, stat=local_output.stat)
        if not skip_postprocessing:
            self.apply_postprocessing(output)

        return output

    def predict_sample_without_blocking(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
        skip_input_padding: bool = False,
        skip_output_cropping: bool = False,
    ) -> Sample:
        """predict a whole sample

        Args:
            sample: input sample
            skip_preprocessing: if `True`, skip all preprocessing steps.
            skip_postprocessing: if `True`, skip all postprocessing steps.
            skip_input_padding: if `True`, skip padding the input sample according to the model's (optional) output halos.
            skip_output_cropping: if `True`, skip cropping any output halos from the model output.
        Note:
            The sample's tensor shapes have to match the model's input tensor description.
            If that is not the case, consider `predict_sample_with_blocking`
        """

        if not skip_input_padding:
            sample = sample.pad(pad_width=self._default_input_halo, mode=self.pad_mode)

        if not skip_preprocessing:
            self.apply_preprocessing(sample)

        output = self._adapter.forward(sample)
        if not skip_postprocessing:
            self.apply_postprocessing(output)

        if not skip_output_cropping:
            # `-0` would yield an empty slice, hence `None` for a zero right halo
            output.members = {
                m: t
                if m not in self._default_output_halo
                else t[
                    {
                        a: slice(h.left, None if h.right == 0 else -h.right)
                        for a, h in self._default_output_halo[m].items()
                    }
                ]
                for m, t in output.members.items()
            }

        return output

    def get_output_sample_id(self, input_sample_id: SampleId):
        """Deprecated: returns **input_sample_id** unchanged."""
        warnings.warn(
            "`PredictionPipeline.get_output_sample_id()` is deprecated and will be"
            + " removed soon. Output sample id is equal to input sample id, hence this"
            + " function is not needed.",
            stacklevel=2,
        )
        return input_sample_id

    def predict_sample_with_fixed_blocking(
        self,
        sample: Sample,
        input_block_shape: Mapping[MemberId, Mapping[AxisId, int]],
        *,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> Sample:
        """Predict `sample` with given `input_block_shape`.

        Samplewise preprocessing is applied to the whole sample, blockwise pre- and
        postprocessing to each block, and samplewise postprocessing to the
        sample reassembled from the predicted blocks.

        Args:
            sample: input sample
            input_block_shape: shape of the input blocks per member and axis.
            skip_preprocessing: if `True`, skip all preprocessing steps.
            skip_postprocessing: if `True`, skip all postprocessing steps.

        Note:
            `input_block_shape` is expected to be a valid input shape for the model.
        """
        if not skip_preprocessing:
            for op in self._samplewise_preprocessing:
                op(sample)

        n_blocks, input_blocks = sample.split_into_blocks(
            input_block_shape,
            halo=self._default_input_halo,
            pad_mode=self.pad_mode,
        )
        # NOTE(review): all blocks are materialized before the first one is
        # predicted; confirm whether lazy iteration would be safe here.
        input_blocks = list(input_blocks)
        predicted_blocks: List[SampleBlock] = []
        logger.info(
            "split sample shape {} into {} blocks of {}.",
            {k: dict(v) for k, v in sample.shape.items()},
            n_blocks,
            {k: dict(v) for k, v in input_block_shape.items()},
        )
        for b in tqdm(
            input_blocks,
            desc=f"predict sample {sample.id or ''} with {self.model_description.id or self.model_description.name}",
            unit="block",
            unit_divisor=1,
            total=n_blocks,
        ):
            if not skip_preprocessing:
                for op in self._blockwise_preprocessing:
                    op(b)

            # processing is handled in this method, so it is skipped in
            # `predict_sample_block`
            predicted = self.predict_sample_block(
                b, skip_preprocessing=True, skip_postprocessing=True
            )
            predicted_blocks.append(predicted)
            if not skip_postprocessing:
                for op in self._blockwise_postprocessing:
                    op(predicted)

        predicted_sample = Sample.from_blocks(predicted_blocks)
        if not skip_postprocessing:
            for op in self._samplewise_postprocessing:
                op(predicted_sample)

        return predicted_sample

    def predict_sample_with_blocking(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
        ns: Optional[
            Union[
                v0_5.ParameterizedSize_N,
                Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
            ]
        ] = None,
        batch_size: Optional[int] = None,
    ) -> Sample:
        """Predict a sample by splitting it into blocks according to the model

        The `ns` parameter allows scaling the model's default input block size.

        Args:
            sample: input sample
            skip_preprocessing: if `True`, skip all preprocessing steps.
            skip_postprocessing: if `True`, skip all postprocessing steps.
            ns: block size parameter; a single integer is used for every
                parameterized input axis. Defaults to the pipeline's default
                block size parameter if `None`.
            batch_size: defaults to the pipeline's default batch size.

        Raises:
            NotImplementedError: for v0_4 model descriptions.
        """

        if isinstance(self.model_description, v0_4.ModelDescr):
            raise NotImplementedError(
                "`predict_sample_with_blocking` not implemented for v0_4.ModelDescr"
                + f" {self.model_description.name}."
                + " Consider using `predict_sample_with_fixed_blocking`"
            )

        # Fall back to the default only if `ns` was not given (or is an empty
        # mapping). An explicit `ns=0` must be honored and not be replaced by
        # the default, which a plain `ns or default` would do.
        if ns is None or (not isinstance(ns, int) and not ns):
            ns = self._default_blocksize_parameter

        if isinstance(ns, int):
            # use the same `n` for every parameterized input axis
            ns = {
                (ipt.id, a.id): ns
                for ipt in self.model_description.inputs
                for a in ipt.axes
                if isinstance(a.size, v0_5.ParameterizedSize)
            }
        input_block_shape = self.model_description.get_tensor_sizes(
            ns, batch_size or self._default_batch_size
        ).inputs

        return self.predict_sample_with_fixed_blocking(
            sample,
            input_block_shape=input_block_shape,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
        )

    def apply_preprocessing(self, sample: Union[Sample, SampleBlock]) -> None:
        """apply preprocessing in-place, may also update sample stats

        Raises:
            NotImplementedError: if **sample** is a `SampleBlock` and there are
                any non-blockwise preprocessing operators in the pipeline
        """
        if isinstance(sample, SampleBlock):
            self.raise_for_non_blockwise_preprocessing()

        for op in self._samplewise_preprocessing + self._blockwise_preprocessing:
            if isinstance(sample, SampleBlock):
                # guaranteed by the check above
                assert isinstance(op, BlockwiseOperator)
                op(sample)
            else:
                op(sample)

    def apply_postprocessing(self, sample: Union[Sample, SampleBlock]) -> None:
        """apply postprocessing in-place, may also update sample stats

        Raises:
            NotImplementedError: if **sample** is a `SampleBlock` and there are
                any non-blockwise postprocessing operators in the pipeline
        """
        if isinstance(sample, SampleBlock):
            self.raise_for_non_blockwise_postprocessing()

        for op in self._blockwise_postprocessing + self._samplewise_postprocessing:
            if isinstance(sample, SampleBlock):
                # guaranteed by the check above
                assert isinstance(op, BlockwiseOperator)
                op(sample)
            else:
                op(sample)

    def load(self):
        """
        optional step: load model onto devices before calling forward if not using it as context manager

        Currently a no-op.
        """
        pass

    def unload(self):
        """
        free any device memory in use
        """
        self._adapter.unload()

415 

416 

def create_prediction_pipeline(
    bioimageio_model: AnyModelDescr,
    *,
    devices: Optional[Sequence[str]] = None,
    weight_format: Optional[SupportedWeightsFormat] = None,
    weights_format: Optional[SupportedWeightsFormat] = None,
    dataset_for_initial_statistics: Iterable[Union[Sample, Sequence[Tensor]]] = tuple(),
    keep_updating_initial_dataset_statistics: bool = False,
    fixed_dataset_statistics: Mapping[Measure, MeasureValue] = MappingProxyType({}),
    model_adapter: Optional[ModelAdapter] = None,
    ns: Optional[BlocksizeParameter] = None,
    default_blocksize_parameter: BlocksizeParameter = 10,  # TODO: default to None and find smart blocksize params per axis to reduce overlap of blocks with large halo
    **deprecated_kwargs: Any,
) -> PredictionPipeline:
    """
    Creates prediction pipeline which includes:
    * computation of input statistics
    * preprocessing
    * model prediction
    * computation of output statistics
    * postprocessing

    Args:
        bioimageio_model: A bioimageio model description.
        devices: (optional) Passed on to the created model adapter.
        weight_format: deprecated in favor of **weights_format**
        weights_format: (optional) Use a specific **weights_format** rather than
            choosing one automatically.
            A corresponding `bioimageio.core.model_adapters.ModelAdapter` will be
            created to run inference with the **bioimageio_model**.
        dataset_for_initial_statistics: (optional) If preprocessing steps require input
            dataset statistics, **dataset_for_initial_statistics** allows you to
            specify a dataset from which these statistics are computed.
        keep_updating_initial_dataset_statistics: (optional) Set to `True` if you want
            to update dataset statistics with each processed sample.
        fixed_dataset_statistics: (optional) Precomputed dataset (and optionally sample) statistics.
            Any included sample statistics will not be calculated on the fly and it is the caller's
            responsibility to use samples with the corresponding statistics available in `sample.stat`.
        model_adapter: (optional) Allows you to use a custom **model_adapter** instead
            of creating one according to the present/selected **weights_format**.
        ns: deprecated in favor of **default_blocksize_parameter**
        default_blocksize_parameter: Allows to control the default block size for
            blockwise predictions, see `BlocksizeParameter`.
        **deprecated_kwargs: Ignored; a warning listing their names is issued.

    """
    weights_format = weight_format or weights_format
    del weight_format
    # An explicit integer `ns` (including `0`) is honored; `None` and an empty
    # mapping fall back to `default_blocksize_parameter`.
    if ns is not None and (isinstance(ns, int) or ns):
        default_blocksize_parameter = ns
    del ns
    if deprecated_kwargs:
        warnings.warn(
            f"deprecated create_prediction_pipeline kwargs: {set(deprecated_kwargs)}",
            stacklevel=2,
        )

    model_adapter = model_adapter or create_model_adapter(
        model_description=bioimageio_model,
        devices=devices,
        weight_format_priority_order=(weights_format,) if weights_format else None,
    )

    input_ids = get_member_ids(bioimageio_model.inputs)

    def dataset():
        """Yield the initial-statistics dataset as `Sample`s.

        Tensor sequences are matched to the model's input ids by position.
        """
        # one stat dict is shared by all samples created from tensor sequences
        common_stat: Stat = {}
        for i, x in enumerate(dataset_for_initial_statistics):
            if isinstance(x, Sample):
                yield x
            else:
                yield Sample(members=dict(zip(input_ids, x)), stat=common_stat, id=i)

    preprocessing, postprocessing = setup_pre_and_postprocessing(
        bioimageio_model,
        dataset(),
        keep_updating_initial_dataset_stats=keep_updating_initial_dataset_statistics,
        fixed_dataset_stats=fixed_dataset_statistics,
    )

    return PredictionPipeline(
        name=bioimageio_model.name,
        model_description=bioimageio_model,
        model_adapter=model_adapter,
        preprocessing=preprocessing,
        postprocessing=postprocessing,
        default_blocksize_parameter=default_blocksize_parameter,
    )