Coverage for src / bioimageio / core / _prediction_pipeline.py: 82%

146 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-30 13:23 +0000

1import warnings 

2from types import MappingProxyType 

3from typing import ( 

4 Any, 

5 Iterable, 

6 List, 

7 Literal, 

8 Mapping, 

9 Optional, 

10 Sequence, 

11 Tuple, 

12 TypeVar, 

13 Union, 

14) 

15 

16from loguru import logger 

17from tqdm import tqdm 

18 

19from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 

20 

21from ._op_base import BlockwiseOperator 

22from .axis import AxisId, PerAxis 

23from .common import ( 

24 BlocksizeParameter, 

25 Halo, 

26 MemberId, 

27 PerMember, 

28 SampleId, 

29 SupportedWeightsFormat, 

30) 

31from .digest_spec import ( 

32 get_block_transform, 

33 get_input_halo, 

34 get_member_ids, 

35) 

36from .model_adapters import ModelAdapter, create_model_adapter 

37from .model_adapters import get_weight_formats as get_weight_formats 

38from .proc_setup import Processing, setup_pre_and_postprocessing 

39from .sample import Sample, SampleBlock 

40from .stat_measures import DatasetMeasure, MeasureValue, Stat 

41from .tensor import Tensor 

42 

# Constrained type variable: either a single `Sample` or an iterable of
# `Sample`s. Lets prediction helpers declare "returns the same kind of
# input it was given" (sample in -> sample out, stream in -> stream out).
Predict_IO = TypeVar(
    "Predict_IO",
    Sample,
    Iterable[Sample],
)

48 

49 

class PredictionPipeline:
    """
    Represents model computation including preprocessing and postprocessing
    Note: Ideally use the `PredictionPipeline` in a with statement
    (as a context manager).
    """

    def __init__(
        self,
        *,
        name: str,
        model_description: AnyModelDescr,
        preprocessing: List[Processing],
        postprocessing: List[Processing],
        model_adapter: ModelAdapter,
        default_ns: Optional[BlocksizeParameter] = None,
        default_blocksize_parameter: BlocksizeParameter = 10,
        default_batch_size: int = 1,
    ) -> None:
        """Consider using `create_prediction_pipeline` to create a `PredictionPipeline` with sensible defaults.

        Args:
            name: Human readable name for this pipeline.
            model_description: The bioimageio model description to run.
            preprocessing: Operators applied to inputs before inference.
            postprocessing: Operators applied to outputs after inference.
            model_adapter: Adapter performing the actual model inference.
            default_ns: Deprecated in favor of **default_blocksize_parameter**.
            default_blocksize_parameter: Default scaling of parameterized input
                block sizes used for blockwise prediction.
            default_batch_size: Default batch size for blockwise prediction.
        """
        super().__init__()
        # the deprecated `default_ns` takes precedence if explicitly given
        default_blocksize_parameter = default_ns or default_blocksize_parameter
        if default_ns is not None:
            warnings.warn(
                "Argument `default_ns` is deprecated in favor of"
                + " `default_blocksize_parameter` and will be removed soon."
            )
        del default_ns

        if model_description.run_mode:
            warnings.warn(
                f"Not yet implemented inference for run mode '{model_description.run_mode.name}'"
            )

        self.name = name
        self._preprocessing = preprocessing
        self._postprocessing = postprocessing

        self.model_description = model_description
        if isinstance(model_description, v0_4.ModelDescr):
            # v0.4 model descriptions do not define halos or block transforms
            self._default_input_halo: PerMember[PerAxis[Halo]] = {}
            self._block_transform = None
        else:
            # derive the input halos from the output halos given in the spec
            default_output_halo = {
                t.id: {
                    a.id: Halo(a.halo, a.halo)
                    for a in t.axes
                    if isinstance(a, v0_5.WithHalo)
                }
                for t in model_description.outputs
            }
            self._default_input_halo = get_input_halo(
                model_description, default_output_halo
            )
            self._block_transform = get_block_transform(model_description)

        self._default_blocksize_parameter = default_blocksize_parameter
        self._default_batch_size = default_batch_size

        self._input_ids = get_member_ids(model_description.inputs)
        self._output_ids = get_member_ids(model_description.outputs)

        self._adapter: ModelAdapter = model_adapter

    def __enter__(self):
        self.load()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):  # type: ignore
        self.unload()
        return False  # do not suppress exceptions

    @property
    def has_blockwise_preprocessing(self) -> bool:
        """`True` if all preprocessing operators in the pipeline are blockwise."""
        return all(isinstance(op, BlockwiseOperator) for op in self._preprocessing)

    @property
    def has_blockwise_postprocessing(self) -> bool:
        """`True` if all postprocessing operators in the pipeline are blockwise."""
        return all(isinstance(op, BlockwiseOperator) for op in self._postprocessing)

    def _raise_for_non_blockwise_processing(
        self, proc_type: Literal["preprocessing", "postprocessing"]
    ):
        """Raise `NotImplementedError` naming any non-blockwise operators of **proc_type**."""
        ops = (
            self._preprocessing
            if proc_type == "preprocessing"
            else self._postprocessing
        )
        non_blockwise = [
            op.__class__.__name__ for op in ops if not isinstance(op, BlockwiseOperator)
        ]
        if non_blockwise:
            raise NotImplementedError(
                f"Blockwise {proc_type} for non-blockwise operators {non_blockwise} not implemented."
            )

    def raise_for_non_blockwise_preprocessing(self):
        """
        Raises:
            NotImplementedError: if there are any non-blockwise preprocessing operators in the pipeline
        """
        self._raise_for_non_blockwise_processing("preprocessing")

    def raise_for_non_blockwise_postprocessing(self):
        """
        Raises:
            NotImplementedError: if there are any non-blockwise postprocessing operators in the pipeline
        """
        self._raise_for_non_blockwise_processing("postprocessing")

    def predict_sample_block(
        self,
        sample_block: SampleBlock,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> SampleBlock:
        """Predict a single sample block.

        Args:
            sample_block: Input block matching the model's input block shape.
            skip_preprocessing: Flag to skip the model's preprocessing.
            skip_postprocessing: Flag to skip the model's postprocessing.

        Raises:
            NotImplementedError: for v0.4 model descriptions, or if a required
                (not skipped) processing step is not blockwise.
        """
        if isinstance(self.model_description, v0_4.ModelDescr):
            raise NotImplementedError(
                f"predict_sample_block not implemented for model {self.model_description.format_version}"
            )
        else:
            assert self._block_transform is not None

        # fail early, before any processing is applied
        if not skip_preprocessing:
            self.raise_for_non_blockwise_preprocessing()

        if not skip_postprocessing:
            self.raise_for_non_blockwise_postprocessing()

        if not skip_preprocessing:
            self.apply_preprocessing(sample_block)

        output_meta = sample_block.get_transformed_meta(self._block_transform)
        local_output = self._adapter.forward(sample_block)

        output = output_meta.with_data(local_output.members, stat=local_output.stat)
        if not skip_postprocessing:
            self.apply_postprocessing(output)

        return output

    def predict_sample_without_blocking(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> Sample:
        """predict a whole sample

        Note:
            The sample's tensor shapes have to match the model's input tensor description.
            If that is not the case, consider `predict_sample_with_blocking`
        """

        if not skip_preprocessing:
            self.apply_preprocessing(sample)

        output = self._adapter.forward(sample)
        if not skip_postprocessing:
            self.apply_postprocessing(output)

        return output

    def get_output_sample_id(self, input_sample_id: SampleId):
        """Deprecated identity function: output sample id equals input sample id."""
        warnings.warn(
            "`PredictionPipeline.get_output_sample_id()` is deprecated and will be"
            + " removed soon. Output sample id is equal to input sample id, hence this"
            + " function is not needed."
        )
        return input_sample_id

    def predict_sample_with_fixed_blocking(
        self,
        sample: Sample,
        input_block_shape: Mapping[MemberId, Mapping[AxisId, int]],
        *,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> Sample:
        """Predict `sample` with given `input_block_shape`.

        Note:
            `input_block_shape` is expected to be a valid input shape for the model.
        """
        if not skip_preprocessing:
            self.apply_preprocessing(sample)

        n_blocks, input_blocks = sample.split_into_blocks(
            input_block_shape,
            halo=self._default_input_halo,
            pad_mode="reflect",
        )
        input_blocks = list(input_blocks)
        predicted_blocks: List[SampleBlock] = []
        logger.info(
            "split sample shape {} into {} blocks of {}.",
            {k: dict(v) for k, v in sample.shape.items()},
            n_blocks,
            {k: dict(v) for k, v in input_block_shape.items()},
        )
        for b in tqdm(
            input_blocks,
            desc=f"predict sample {sample.id or ''} with {self.model_description.id or self.model_description.name}",
            unit="block",
            unit_divisor=1,
            total=n_blocks,
        ):
            # per-block pre-/postprocessing is skipped here; it was applied to
            # the whole sample above (pre) or is applied after reassembly (post)
            predicted_blocks.append(
                self.predict_sample_block(
                    b, skip_preprocessing=True, skip_postprocessing=True
                )
            )

        predicted_sample = Sample.from_blocks(predicted_blocks)
        if not skip_postprocessing:
            self.apply_postprocessing(predicted_sample)

        return predicted_sample

    def predict_sample_with_blocking(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
        ns: Optional[
            Union[
                v0_5.ParameterizedSize_N,
                Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
            ]
        ] = None,
        batch_size: Optional[int] = None,
    ) -> Sample:
        """Predict a sample by splitting it into blocks according to the mode

        The `ns` parameter allow scaling the model's default input block size.
        """

        if isinstance(self.model_description, v0_4.ModelDescr):
            raise NotImplementedError(
                "`predict_sample_with_blocking` not implemented for v0_4.ModelDescr"
                + f" {self.model_description.name}."
                + " Consider using `predict_sample_with_fixed_blocking`"
            )

        ns = ns or self._default_blocksize_parameter
        if isinstance(ns, int):
            # expand a scalar `ns` to every parameterized input axis
            ns = {
                (ipt.id, a.id): ns
                for ipt in self.model_description.inputs
                for a in ipt.axes
                if isinstance(a.size, v0_5.ParameterizedSize)
            }
        input_block_shape = self.model_description.get_tensor_sizes(
            ns, batch_size or self._default_batch_size
        ).inputs

        return self.predict_sample_with_fixed_blocking(
            sample,
            input_block_shape=input_block_shape,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
        )

    def apply_preprocessing(self, sample: Union[Sample, SampleBlock]) -> None:
        """apply preprocessing in-place; may also update sample stats"""
        if isinstance(sample, SampleBlock):
            # a block can only be processed by blockwise operators; this raises
            # for any non-blockwise operator, so the loop below is safe
            self.raise_for_non_blockwise_preprocessing()

        for op in self._preprocessing:
            op(sample)

    def apply_postprocessing(self, sample: Union[Sample, SampleBlock]) -> None:
        """apply postprocessing in-place; may also update sample stats"""
        if isinstance(sample, SampleBlock):
            # a block can only be processed by blockwise operators; this raises
            # for any non-blockwise operator, so the loop below is safe
            self.raise_for_non_blockwise_postprocessing()

        for op in self._postprocessing:
            op(sample)

    def load(self):
        """
        optional step: load model onto devices before calling forward if not using it as context manager
        """
        pass

    def unload(self):
        """
        free any device memory in use
        """
        self._adapter.unload()

350 

351 

def create_prediction_pipeline(
    bioimageio_model: AnyModelDescr,
    *,
    devices: Optional[Sequence[str]] = None,
    weight_format: Optional[SupportedWeightsFormat] = None,
    weights_format: Optional[SupportedWeightsFormat] = None,
    dataset_for_initial_statistics: Iterable[Union[Sample, Sequence[Tensor]]] = tuple(),
    keep_updating_initial_dataset_statistics: bool = False,
    fixed_dataset_statistics: Mapping[DatasetMeasure, MeasureValue] = MappingProxyType(
        {}
    ),
    model_adapter: Optional[ModelAdapter] = None,
    ns: Optional[BlocksizeParameter] = None,
    default_blocksize_parameter: BlocksizeParameter = 10,  # TODO: default to None and find smart blocksize params per axis to reduce overlap of blocks with large halo
    **deprecated_kwargs: Any,
) -> PredictionPipeline:
    """
    Creates prediction pipeline which includes:
    * computation of input statistics
    * preprocessing
    * model prediction
    * computation of output statistics
    * postprocessing

    Args:
        bioimageio_model: A bioimageio model description.
        devices: (optional)
        weight_format: deprecated in favor of **weights_format**
        weights_format: (optional) Use a specific **weights_format** rather than
            choosing one automatically.
            A corresponding `bioimageio.core.model_adapters.ModelAdapter` will be
            created to run inference with the **bioimageio_model**.
        dataset_for_initial_statistics: (optional) If preprocessing steps require input
            dataset statistics, **dataset_for_initial_statistics** allows you to
            specify a dataset from which these statistics are computed.
        keep_updating_initial_dataset_statistics: (optional) Set to `True` if you want
            to update dataset statistics with each processed sample.
        fixed_dataset_statistics: (optional) Allows you to specify a mapping of
            `DatasetMeasure`s to precomputed `MeasureValue`s.
        model_adapter: (optional) Allows you to use a custom **model_adapter** instead
            of creating one according to the present/selected **weights_format**.
        ns: deprecated in favor of **default_blocksize_parameter**
        default_blocksize_parameter: Allows to control the default block size for
            blockwise predictions, see `BlocksizeParameter`.

    """
    # deprecated `weight_format` and `ns` take precedence if explicitly given
    weights_format = weight_format or weights_format
    del weight_format
    default_blocksize_parameter = ns or default_blocksize_parameter
    del ns
    if deprecated_kwargs:
        warnings.warn(
            f"deprecated create_prediction_pipeline kwargs: {set(deprecated_kwargs)}"
        )

    model_adapter = model_adapter or create_model_adapter(
        model_description=bioimageio_model,
        devices=devices,
        # `weights_format and (weights_format,)` is None when no format is
        # requested, letting the adapter pick from its own priority order
        weight_format_priority_order=weights_format and (weights_format,),
    )

    input_ids = get_member_ids(bioimageio_model.inputs)

    def dataset():
        # one `common_stat` dict is shared by all samples built from raw
        # tensors so that dataset statistics accumulate in a single place
        common_stat: Stat = {}
        for i, x in enumerate(dataset_for_initial_statistics):
            if isinstance(x, Sample):
                yield x
            else:
                yield Sample(members=dict(zip(input_ids, x)), stat=common_stat, id=i)

    preprocessing, postprocessing = setup_pre_and_postprocessing(
        bioimageio_model,
        dataset(),
        keep_updating_initial_dataset_stats=keep_updating_initial_dataset_statistics,
        fixed_dataset_stats=fixed_dataset_statistics,
    )

    return PredictionPipeline(
        name=bioimageio_model.name,
        model_description=bioimageio_model,
        model_adapter=model_adapter,
        preprocessing=preprocessing,
        postprocessing=postprocessing,
        default_blocksize_parameter=default_blocksize_parameter,
    )