Coverage for src/bioimageio/core/_prediction_pipeline.py: 83%

226 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-06-22 16:54 +0000

1import warnings 

2from abc import ABC, abstractmethod 

3from types import MappingProxyType 

4from typing import ( 

5 Any, 

6 Iterable, 

7 List, 

8 Literal, 

9 Mapping, 

10 NamedTuple, 

11 Optional, 

12 Sequence, 

13 Tuple, 

14 TypeVar, 

15 Union, 

16) 

17 

18from loguru import logger 

19from tqdm import tqdm 

20from typing_extensions import assert_never 

21 

22from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 

23 

24from ._model_adapter import ModelAdapter 

25from ._op_base import BlockwiseOperator, SamplewiseOperator 

26from .axis import AxisId, PerAxis 

27from .backends import create_model_adapter 

28from .common import ( 

29 BlocksizeParameter, 

30 Halo, 

31 MemberId, 

32 PerMember, 

33 SampleId, 

34 SupportedWeightsFormat, 

35) 

36from .digest_spec import ( 

37 get_block_transform, 

38 get_input_halo, 

39 get_member_ids, 

40) 

41from .proc_ops import Processing 

42from .proc_setup import setup_pre_and_postprocessing 

43from .sample import Sample, SampleBlock 

44from .stat_measures import Measure, MeasureValue, Stat 

45from .tensor import Tensor 

46 

# Type variable constrained to either a single `Sample` or an iterable of samples.
Predict_IO = TypeVar("Predict_IO", Sample, Iterable[Sample])

52 

53 

class IntermediatePrediction(NamedTuple):
    """Snapshot of a blockwise prediction in progress.

    Pairs the output sample predicted so far with the block that was predicted most recently.
    The last `IntermediatePrediction` of a sequence carries the complete predicted
    (and, where applicable, postprocessed) sample.
    """

    sample: Sample  # output sample assembled from all blocks predicted so far
    last_block: SampleBlock  # the block that was predicted most recently

61 

62 

class _PredictionPipelineBase(ABC):
    """Shared base of local and remote prediction pipelines.

    Stores the model description and the default blocking parameters, and derives the
    default input/output halos, the block transform and per-input padding modes from it.
    """

    def __init__(
        self,
        model_descr: AnyModelDescr,
        *,
        default_blocksize_parameter: BlocksizeParameter,
        default_batch_size: int,
    ) -> None:
        """
        Args:
            model_descr: The model description to run predictions for.
            default_blocksize_parameter: Block size parameter used by
                `predict_sample_with_blocking*` when no `ns` is given.
            default_batch_size: Batch size used by `predict_sample_with_blocking*`
                when no `batch_size` is given.
        """
        super().__init__()
        self._model_descr = model_descr
        self._default_blocksize_parameter = default_blocksize_parameter
        self._default_batch_size = default_batch_size

        if isinstance(model_descr, v0_4.ModelDescr):
            # no halo / block transform information is derived for v0_4 descriptions
            self._default_output_halo: PerMember[PerAxis[Halo]] = {}
            self._default_input_halo: PerMember[PerAxis[Halo]] = {}
            self._block_transform = None
        else:
            # symmetric output halo for every output axis that carries a `halo`
            self._default_output_halo = {
                t.id: {
                    a.id: Halo(a.halo, a.halo)
                    for a in t.axes
                    if isinstance(a, v0_5.WithHalo)
                }
                for t in model_descr.outputs
            }
            self._default_input_halo = get_input_halo(
                model_descr, self._default_output_halo
            )
            self._block_transform = get_block_transform(model_descr)

        # per input padding mode; falls back to `v0_5.SymmetricPadding()` if `descr.pad` is falsy
        self.pad_mode = (
            {}
            if isinstance(model_descr, v0_4.ModelDescr)
            else {
                descr.id: descr.pad or v0_5.SymmetricPadding()
                for descr in model_descr.inputs
            }
        )

    @property
    def model_descr(self) -> AnyModelDescr:
        """The model description this pipeline runs."""
        return self._model_descr

    @property
    def model_description(self) -> AnyModelDescr:
        """Alias of `model_descr`."""
        return self._model_descr

    @staticmethod
    def _final_sample(intermediates: Iterable[IntermediatePrediction]) -> Sample:
        """Exhaust `intermediates` and return the sample of the last intermediate prediction."""
        last: Optional[IntermediatePrediction] = None
        for last in intermediates:
            pass

        assert last is not None, (
            "No blocks were predicted, cannot return final sample."
        )
        return last.sample

    @abstractmethod
    def predict_sample_without_blocking(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
        skip_input_padding: bool = False,
        skip_output_cropping: bool = False,
    ) -> Sample:
        """Predict a whole sample at once.

        Note:
            The sample's tensor shapes have to match the model's input tensor description.
            If that is not the case, consider `predict_sample_with_blocking`

        Args:
            sample: input sample
            skip_preprocessing: if `True`, skip all preprocessing steps.
            skip_postprocessing: if `True`, skip all postprocessing steps.
            skip_input_padding: if `True`, skip padding the input sample according to the model's (optional) output halos.
            skip_output_cropping: if `True`, skip cropping any output halos from the model output.
        """

    def predict_sample_with_blocking(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
        ns: Optional[
            Union[
                v0_5.ParameterizedSize_N,
                Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
            ]
        ] = None,
        batch_size: Optional[int] = None,
    ) -> Sample:
        """Predict a sample by predicting sample blocks.

        Note: For fixed/known blocksizes use `predict_sample_with_fixed_blocking`.

        Args:
            sample: The sample to predict on.
            skip_preprocessing: If `True`, skip all preprocessing steps.
            skip_postprocessing: If `True`, skip all postprocessing steps.
            ns: Block size parameter(s) allows scaling the model's default input block size.
                Blocksize parameters are only applied to parameterized input axes, all other axis sizes are fixed/derived or (for output axes) data dependent.
                Unapplicable blocksize parameters are ignored.
                If `None`, the pipeline's default block size parameter is used.
            batch_size: Batch size to use for prediction.

        Returns:
            The completely predicted (and postprocessed unless skipped) sample.
        """
        _, intermediates = self.predict_sample_with_blocking_yield_intermediates(
            sample,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
            ns=ns,
            batch_size=batch_size,
        )
        return self._final_sample(intermediates)

    def predict_sample_with_fixed_blocking(
        self,
        sample: Sample,
        input_block_shape: PerMember[PerAxis[int]],
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> Sample:
        """Predict `sample` with given `input_block_shape`.

        Note:
            - `input_block_shape` is expected to be a valid input shape for the model.
            - Use `predict_sample_with_blocking` if you want to control block sizes via generic block size parameters rather than fixed block shapes.

        Args:
            sample: The sample to predict on.
            input_block_shape: Mapping of input member id to mapping of axis id to block size for that axis.
            skip_preprocessing: If `True`, skip all preprocessing steps.
            skip_postprocessing: If `True`, skip all postprocessing steps.

        Returns:
            The completely predicted (and postprocessed unless skipped) sample.
        """
        _, intermediates = self.predict_sample_with_fixed_blocking_yield_intermediates(
            sample,
            input_block_shape=input_block_shape,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
        )
        return self._final_sample(intermediates)

    def predict_sample_with_blocking_yield_intermediates(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
        ns: Optional[
            Union[
                v0_5.ParameterizedSize_N,
                Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
            ]
        ] = None,
        batch_size: Optional[int] = None,
    ) -> Tuple[int, Iterable[IntermediatePrediction]]:
        """Predict `sample` by predicting sample blocks and yield intermediate predictions.

        Block shapes are derived from the block size parameter(s) `ns` and `batch_size`;
        the actual prediction is delegated to `predict_sample_with_fixed_blocking_yield_intermediates`.

        Args:
            sample: The sample to predict on.
            skip_preprocessing: If `True`, skip all preprocessing steps.
            skip_postprocessing: If `True`, skip all postprocessing steps.
            ns: Block size parameter(s), see `predict_sample_with_blocking`.
                If `None` (or an empty mapping), the pipeline's default is used.
                Note that `0` is a valid block size parameter and is respected.
            batch_size: Batch size to use for prediction; falls back to the pipeline's default.

        Returns:
            Tuple of number of blocks and an iterator of predicted intermediate samples with the last predicted block.
            All samples, but the last one, are intermediate samples with more and more blocks predicted.
            Samplewise postprocessing (if any) can only be applied once all blocks are predicted,
            so only the last sample is guaranteed to be completely postprocessed.

        Raises:
            NotImplementedError: for v0_4 model descriptions.
        """
        if isinstance(self._model_descr, v0_4.ModelDescr):
            raise NotImplementedError(
                "`predict_sample_with_blocking` not implemented for v0_4.ModelDescr"
                + f" {self._model_descr.name}."
                + " Consider using `predict_sample_with_fixed_blocking`"
            )

        # Compare against `None` instead of relying on truthiness: `ns=0` is a valid
        # block size parameter and must not be replaced by the default.
        # An empty mapping still falls back to the default (as before).
        if ns is None or (not isinstance(ns, int) and not ns):
            ns = self._default_blocksize_parameter

        if isinstance(ns, int):
            # apply the single block size parameter to every parameterized input axis
            ns = {
                (ipt.id, a.id): ns
                for ipt in self._model_descr.inputs
                for a in ipt.axes
                if isinstance(a.size, v0_5.ParameterizedSize)
            }
        input_block_shape = self._model_descr.get_tensor_sizes(
            ns, batch_size or self._default_batch_size
        ).inputs

        return self.predict_sample_with_fixed_blocking_yield_intermediates(
            sample,
            input_block_shape=input_block_shape,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
        )

    @abstractmethod
    def predict_sample_with_fixed_blocking_yield_intermediates(
        self,
        sample: Sample,
        input_block_shape: PerMember[PerAxis[int]],
        *,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
        fill_value: float = float("nan"),
    ) -> Tuple[int, Iterable[IntermediatePrediction]]:
        """Predict `sample` with given `input_block_shape` and yield intermediate predictions.

        Args:
            sample: The sample to predict on.
            input_block_shape: Mapping of input member id to mapping of axis id to block size for that axis.
            skip_preprocessing: If `True`, skip all preprocessing steps.
            skip_postprocessing: If `True`, skip all postprocessing steps.
            fill_value: Value for regions of the output sample that have not been
                predicted yet (presumably only visible in intermediate samples — implementation dependent).

        Returns:
            Tuple of number of blocks and an iterable of intermediate predictions.
        """
        ...

    @abstractmethod
    def predict_sample_block(
        self,
        sample_block: SampleBlock,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> SampleBlock:
        """Predict a single sample block.

        Note that this does not apply samplewise preprocessing or postprocessing steps, but only blockwise ones.

        Args:
            sample_block: The sample block to predict on.
            skip_preprocessing: If `True`, skip blockwise preprocessing steps.
            skip_postprocessing: If `True`, skip blockwise postprocessing steps.
        """

281 

282 

class PredictionPipeline(_PredictionPipelineBase):
    """Local model computation including preprocessing and postprocessing.

    Preprocessing and postprocessing operators are split into a samplewise part and a
    blockwise part so that blockwise prediction applies as many operators per block as possible.

    Note: Ideally use the `PredictionPipeline` in a with statement
    (as a context manager); it calls `load()` on entry and `unload()` on exit.
    """

    def __init__(
        self,
        *,
        name: str,
        model_description: AnyModelDescr,
        preprocessing: List[Processing],
        postprocessing: List[Processing],
        model_adapter: ModelAdapter,
        default_blocksize_parameter: BlocksizeParameter = 10,
        default_batch_size: int = 1,
    ) -> None:
        """Consider using `create_prediction_pipeline` to create a `PredictionPipeline` with sensible defaults."""
        super().__init__(
            model_descr=model_description,
            default_blocksize_parameter=default_blocksize_parameter,
            default_batch_size=default_batch_size,
        )

        if model_description.run_mode:
            warnings.warn(
                f"Not yet implemented inference for run mode '{model_description.run_mode.name}'"
            )

        self.name = name
        # Split preprocessing into samplewise and blockwise parts: samplewise preprocessing
        # is everything up to and including the last operator that is not a
        # `BlockwiseOperator`; blockwise preprocessing is the trailing run of
        # `BlockwiseOperator`s after it. Hence samplewise preprocessing may contain
        # blockwise operators (when a samplewise operator follows them).
        self._samplewise_preprocessing: List[
            Union[SamplewiseOperator, BlockwiseOperator]
        ] = []
        self._blockwise_preprocessing: List[BlockwiseOperator] = []
        for op in preprocessing[::-1]:
            if isinstance(op, BlockwiseOperator) and not self._samplewise_preprocessing:
                self._blockwise_preprocessing.insert(0, op)
            else:
                self._samplewise_preprocessing.insert(0, op)

        # Split postprocessing analogously, but start blockwise and switch to samplewise
        # at the first operator that is not a `BlockwiseOperator`.
        self._blockwise_postprocessing: List[BlockwiseOperator] = []
        self._samplewise_postprocessing: List[
            Union[BlockwiseOperator, SamplewiseOperator]
        ] = []
        for op in postprocessing:
            if (
                isinstance(op, BlockwiseOperator)
                and not self._samplewise_postprocessing
            ):
                self._blockwise_postprocessing.append(op)
            else:
                self._samplewise_postprocessing.append(op)

        self._input_ids = get_member_ids(model_description.inputs)
        self._output_ids = get_member_ids(model_description.outputs)

        self._adapter = model_adapter

    def __enter__(self):
        self.load()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):  # type: ignore
        self.unload()
        return False  # do not suppress exceptions

    @property
    def has_non_blockwise_preprocessing(self) -> bool:
        """`True` if any preprocessing operators in the pipeline are not applicable blockwise."""
        return bool(self._samplewise_preprocessing)

    @property
    def has_non_blockwise_postprocessing(self) -> bool:
        """`True` if any postprocessing operators in the pipeline are not applicable blockwise."""
        return bool(self._samplewise_postprocessing)

    def _raise_for_non_blockwise_processing(
        self, proc_type: Literal["preprocessing", "postprocessing"]
    ):
        """Raise `NotImplementedError` listing the non-blockwise operators of `proc_type`, if any."""
        ops = (
            self._samplewise_preprocessing
            if proc_type == "preprocessing"
            else self._samplewise_postprocessing
        )
        non_blockwise = [
            op.__class__.__name__ for op in ops if not isinstance(op, BlockwiseOperator)
        ]
        if non_blockwise:
            raise NotImplementedError(
                f"Blockwise {proc_type} for {non_blockwise} not implemented."
            )

    def raise_for_non_blockwise_preprocessing(self):
        """
        Raises:
            NotImplementedError: if there are any non-blockwise preprocessing operators in the pipeline
        """
        self._raise_for_non_blockwise_processing("preprocessing")

    def raise_for_non_blockwise_postprocessing(self):
        """
        Raises:
            NotImplementedError: if there are any non-blockwise postprocessing operators in the pipeline
        """
        self._raise_for_non_blockwise_processing("postprocessing")

    def predict_sample_block(
        self,
        sample_block: SampleBlock,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> SampleBlock:
        """Predict a single sample block, applying only blockwise pre- and postprocessing.

        Args:
            sample_block: The sample block to predict on (preprocessed in-place unless skipped).
            skip_preprocessing: If `True`, skip blockwise preprocessing steps.
            skip_postprocessing: If `True`, skip blockwise postprocessing steps.

        Raises:
            NotImplementedError: for v0_4 model descriptions.
        """
        if isinstance(self._model_descr, v0_4.ModelDescr):
            raise NotImplementedError(
                f"predict_sample_block not implemented for model {self._model_descr.format_version}"
            )
        else:
            assert self._block_transform is not None

        if not skip_preprocessing:
            self._apply_blockwise_preprocessing(sample_block)

        output_meta = sample_block.get_transformed_meta(self._block_transform)
        local_output = self._adapter.forward(sample_block.members)

        # drop outputs the adapter did not produce (`None`)
        output = output_meta.with_data(
            {k: v for k, v in local_output.items() if v is not None},
            stat=sample_block.stat,
        )
        if not skip_postprocessing:
            self._apply_blockwise_postprocessing(output)

        return output

    def predict_sample_without_blocking(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
        skip_input_padding: bool = False,
        skip_output_cropping: bool = False,
    ) -> Sample:
        """Predict a whole sample at once.

        Args:
            sample: input sample
            skip_preprocessing: if `True`, skip all preprocessing steps.
            skip_postprocessing: if `True`, skip all postprocessing steps.
            skip_input_padding: if `True`, skip padding the input by the default input halo.
            skip_output_cropping: if `True`, skip cropping the default output halo from the outputs.
        """
        if not skip_input_padding:
            sample = sample.pad(pad_width=self._default_input_halo, mode=self.pad_mode)

        if not skip_preprocessing:
            self.apply_preprocessing(sample)

        output = Sample(
            members={
                k: v
                for k, v in self._adapter.forward(sample.members).items()
                if v is not None
            },
            stat=sample.stat,
            id=sample.id,
        )
        if not skip_postprocessing:
            self.apply_postprocessing(output)

        if not skip_output_cropping:
            # crop `h.left` from the start and `h.right` from the end of each halo axis
            output.members = {
                m: t
                if m not in self._default_output_halo
                else t[
                    {
                        a: slice(h.left, None if h.right == 0 else -h.right)
                        for a, h in self._default_output_halo[m].items()
                    }
                ]
                for m, t in output.members.items()
            }

        return output

    def get_output_sample_id(self, input_sample_id: SampleId):
        """Deprecated: the output sample id equals the input sample id."""
        warnings.warn(
            "`PredictionPipeline.get_output_sample_id()` is deprecated and will be"
            + " removed soon. Output sample id is equal to input sample id, hence this"
            + " function is not needed."
        )
        return input_sample_id

    def predict_sample_with_fixed_blocking_yield_intermediates(
        self,
        sample: Sample,
        input_block_shape: Mapping[MemberId, Mapping[AxisId, int]],
        *,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
        fill_value: float = float("nan"),
    ) -> Tuple[int, Iterable[IntermediatePrediction]]:
        """Predict `sample` with given `input_block_shape` and yield the full sample with intermediate results.

        Note:
            - `input_block_shape` is expected to be a valid input shape for the model.
            - Use `predict_sample_with_blocking` if you want to control block sizes via generic block size parameters
              rather than fixed block shapes.
            - Postprocessing may only be complete for the final sample (if samplewise postprocessing steps are included
              in the pipeline), intermediate samples may have some (blockwise applicable) postprocessing steps applied.
            - Samplewise preprocessing is applied to `sample` in-place before this function returns;
              blocks are only predicted while the returned iterable is consumed.

        Args:
            sample: The sample to predict on.
            input_block_shape: Mapping of input member id to mapping of axis id to block size for that axis.
            skip_preprocessing: If `True`, skip all preprocessing steps.
            skip_postprocessing: If `True`, skip all postprocessing steps.
            fill_value: Passed to `Sample.from_blocks` when the output sample is created from the first block.

        Returns:
            Tuple of number of blocks and an iterable of predicted intermediate samples with the last predicted block,
            All samples, but the last one, are intermediate samples with more and more blocks predicted.
        """

        if not skip_preprocessing:
            self._apply_samplewise_preprocessing(sample)

        n_blocks, input_blocks = sample.split_into_blocks(
            input_block_shape,
            halo=self._default_input_halo,
            pad_mode=self.pad_mode,
        )
        logger.info(
            "split sample shape {} into {} blocks of {}.",
            {k: dict(v) for k, v in sample.shape.items()},
            n_blocks,
            {k: dict(v) for k, v in input_block_shape.items()},
        )

        def _predict_blocks():
            """Predict block by block, stitching results into one output sample."""
            predicted_sample = None
            for i, b in enumerate(
                tqdm(
                    input_blocks,
                    desc=f"predict sample {sample.id or ''} with {self._model_descr.id or self._model_descr.name}",
                    unit="block",
                    unit_divisor=1,
                    total=n_blocks,
                )
            ):
                if not skip_preprocessing:
                    self._apply_blockwise_preprocessing(b)

                predicted_block = self.predict_sample_block(
                    b, skip_preprocessing=True, skip_postprocessing=True
                )

                if not skip_postprocessing:
                    self._apply_blockwise_postprocessing(predicted_block)

                if predicted_sample is None:
                    predicted_sample = Sample.from_blocks(
                        [predicted_block], fill_value=fill_value
                    )
                else:
                    predicted_sample.set_block(predicted_block)

                # samplewise postprocessing needs the complete sample: apply after the last block
                if not skip_postprocessing and i == n_blocks - 1:
                    self._apply_samplewise_postprocessing(predicted_sample)

                yield IntermediatePrediction(predicted_sample, predicted_block)

        return n_blocks, _predict_blocks()

    def _apply_samplewise_preprocessing(self, sample: Sample, /) -> None:
        """Apply preprocessing operators up to and including the last samplewise operator in-place.

        Note: This skips all blockwise preprocessing steps after the last samplewise operator.
        """
        if isinstance(sample, SampleBlock):
            self.raise_for_non_blockwise_preprocessing()

        for op in self._samplewise_preprocessing:
            op(sample)

    def _apply_blockwise_preprocessing(
        self, sample_block: Union[Sample, SampleBlock], /
    ) -> None:
        """Apply blockwise preprocessing operators in-place.

        Note: This skips all preprocessing operators up to and including the last samplewise one.
        """
        for op in self._blockwise_preprocessing:
            op(sample_block)

    def apply_preprocessing(self, sample: Union[Sample, SampleBlock]) -> None:
        """Apply preprocessing in-place, may also update sample stats.

        Raises:
            NotImplementedError: if `sample` is not a `Sample` and the pipeline contains
                preprocessing operators that cannot be applied blockwise.
        """

        if isinstance(sample, Sample):
            self._apply_samplewise_preprocessing(sample)
        else:
            self.raise_for_non_blockwise_preprocessing()

        self._apply_blockwise_preprocessing(sample)

    def _apply_blockwise_postprocessing(
        self, sample_block: Union[Sample, SampleBlock], /
    ) -> None:
        """Apply in-place blockwise postprocessing operators

        Note: This does not apply any postprocessing operators from the first samplewise one onwards.
        """
        for op in self._blockwise_postprocessing:
            op(sample_block)

    def _apply_samplewise_postprocessing(self, sample: Sample, /) -> None:
        """Apply in-place postprocessing operators starting from and including the first samplewise operator.

        Note: This skips all blockwise postprocessing steps before the first samplewise one.
        """
        if isinstance(sample, SampleBlock):
            self.raise_for_non_blockwise_postprocessing()

        for op in self._samplewise_postprocessing:
            op(sample)

    def apply_postprocessing(self, sample: Union[Sample, SampleBlock]) -> None:
        """Apply postprocessing in-place, may also update sample stats.

        Raises:
            NotImplementedError: if `sample` is not a `Sample` and the pipeline contains
                postprocessing operators that cannot be applied blockwise. The check happens
                before any operator runs, so the block is left untouched in that case.
        """
        if not isinstance(sample, Sample):
            # check first (as `apply_preprocessing` does) so a block is never left
            # partially postprocessed when an exception is raised
            self.raise_for_non_blockwise_postprocessing()

        self._apply_blockwise_postprocessing(sample)
        if isinstance(sample, Sample):
            self._apply_samplewise_postprocessing(sample)

    def load(self):
        """Prepare prediction pipeline for use.

        Reusable model adapters may be loaded and unloaded multiple times, but currently not all model adapters
        cleanly unload and reload.

        Note:
            For some model adapters loading is currently part of the constructor making them unusable after unloading.
        """
        self._adapter.load()

    def unload(self):
        """Free any device memory in use.

        Note:
            Currently prediction pipeline becomes unusable after unloading."""
        self._adapter.unload()

    def close(self):
        """Permanently close the prediction pipeline and free any device memory in use.
        This makes the prediction pipeline unusable afterwards."""
        self.unload()

658 

659 

class RemotePredictionPipeline(_PredictionPipelineBase):
    """Abstract base class for prediction pipelines that run entirely on a remote server.

    Note: A ("local") `PredictionPipeline` may also delegate model inference to a remote
    service via a `RemoteModelAdapter`, yet still run preprocessing and postprocessing locally.
    A `RemotePredictionPipeline`, in contrast, is meant for the case in which every step,
    preprocessing and postprocessing included, happens remotely.
    """

    def __init__(
        self,
        model_descr: AnyModelDescr,
        *,
        server: str,
        default_blocksize_parameter: BlocksizeParameter,
        default_batch_size: int,
    ) -> None:
        """
        Args:
            model_descr: The model description to run predictions for.
            server: Address of the remote server (stored as-is).
            default_blocksize_parameter: Forwarded to the base class.
            default_batch_size: Forwarded to the base class.
        """
        base_kwargs = dict(
            default_blocksize_parameter=default_blocksize_parameter,
            default_batch_size=default_batch_size,
        )
        super().__init__(model_descr, **base_kwargs)
        self._server = server

    @property
    def server(self) -> str:
        """The remote server this pipeline talks to."""
        return self._server

687 

688 

def create_prediction_pipeline(
    bioimageio_model: AnyModelDescr,
    *,
    devices: Optional[Sequence[str]] = None,
    weight_format: Optional[SupportedWeightsFormat] = None,
    weights_format: Optional[SupportedWeightsFormat] = None,
    dataset_for_initial_statistics: Iterable[Union[Sample, Sequence[Tensor]]] = tuple(),
    keep_updating_initial_dataset_statistics: bool = False,
    fixed_dataset_statistics: Mapping[Measure, MeasureValue] = MappingProxyType({}),
    model_adapter: Optional[ModelAdapter] = None,
    ns: Optional[BlocksizeParameter] = None,
    default_blocksize_parameter: BlocksizeParameter = 10,  # TODO: default to None and find smart blocksize params per axis to reduce overlap of blocks with large halo
    default_batch_size: int = 1,
    **deprecated_kwargs: Any,
) -> PredictionPipeline:
    """
    Creates prediction pipeline which includes:
    * computation of input statistics
    * preprocessing
    * model prediction
    * computation of output statistics
    * postprocessing

    Args:
        bioimageio_model: A bioimageio model description.
        devices: (optional) Devices to run the created model adapter on.
        weight_format: deprecated in favor of **weights_format**
        weights_format: (optional) Use a specific **weights_format** rather than
            choosing one automatically.
            A corresponding `bioimageio.core.model_adapters.ModelAdapter` will be
            created to run inference with the **bioimageio_model**.
        dataset_for_initial_statistics: (optional) If preprocessing steps require input
            dataset statistics, **dataset_for_initial_statistics** allows you to
            specify a dataset from which these statistics are computed.
            Items may be `Sample`s or sequences of tensors ordered like the model inputs.
        keep_updating_initial_dataset_statistics: (optional) Set to `True` if you want
            to update dataset statistics with each processed sample.
        fixed_dataset_statistics: (optional) Precomputed dataset (and optionally sample) statistics.
            Any included sample statistics will not be calculated on the fly and it is the caller's
            responsibility to use samples with the corresponding statistics available in `sample.stat`.
        model_adapter: (optional) Allows you to use a custom **model_adapter** instead
            of creating one according to the present/selected **weights_format**.
        ns: deprecated in favor of **default_blocksize_parameter**
        default_blocksize_parameter: Allows to control the default block size for
            blockwise predictions, see `BlocksizeParameter`.
        default_batch_size: Default batch size for blockwise predictions.
        **deprecated_kwargs: Ignored; a warning is emitted if any are given.

    """
    weights_format = weight_format or weights_format
    del weight_format
    # `ns=0` is a valid block size parameter, so do not rely on truthiness here;
    # `None` (and, as before, an empty mapping) means "not given".
    if ns is not None and (isinstance(ns, int) or ns):
        default_blocksize_parameter = ns
    del ns
    if deprecated_kwargs:
        warnings.warn(
            f"deprecated create_prediction_pipeline kwargs: {set(deprecated_kwargs)}"
        )

    model_adapter = model_adapter or create_model_adapter(
        model_description=bioimageio_model,
        devices=devices,
        weight_format_priority_order=weights_format and (weights_format,),
    )

    input_ids = get_member_ids(bioimageio_model.inputs)

    def dataset():
        """Yield `Sample`s, wrapping plain tensor sequences by zipping them with the input ids."""
        # NOTE(review): this single `common_stat` dict is shared by every wrapped sample;
        # confirm that statistics stored per sample cannot leak between samples.
        common_stat: Stat = {}
        for i, x in enumerate(dataset_for_initial_statistics):
            if isinstance(x, Sample):
                yield x
            else:
                yield Sample(members=dict(zip(input_ids, x)), stat=common_stat, id=i)

    preprocessing, postprocessing = setup_pre_and_postprocessing(
        bioimageio_model,
        dataset(),
        keep_updating_initial_dataset_stats=keep_updating_initial_dataset_statistics,
        fixed_dataset_stats=fixed_dataset_statistics,
    )

    return PredictionPipeline(
        name=bioimageio_model.name,
        model_description=bioimageio_model,
        model_adapter=model_adapter,
        preprocessing=preprocessing,
        postprocessing=postprocessing,
        default_blocksize_parameter=default_blocksize_parameter,
        default_batch_size=default_batch_size,
    )

774 

775 

def create_remote_prediction_pipeline(
    model_description: AnyModelDescr,
    *,
    server: Optional[str] = None,
    server_type: Optional[Literal["gradio"]] = "gradio",
    precomputed_statistics: Mapping[Measure, MeasureValue] = MappingProxyType({}),
    default_blocksize_parameter: BlocksizeParameter = 10,  # TODO: default to None and find smart blocksize params per axis to reduce overlap of blocks with large halo
    default_batch_size: int = 1,
) -> RemotePredictionPipeline:
    """Build a `RemotePredictionPipeline` that runs `model_description` on a remote server.

    Args:
        model_description: Model description to run inference with.
        server: URL or Hugging Face space name of a running bioimageio server instance.
        server_type: Kind of remote server to connect to; `None` means "gradio",
            which is currently the only supported option.
        precomputed_statistics: Precomputed dataset (and optionally sample) statistics.
            Sample statistics included here are not computed on the fly; the caller is
            responsible for providing samples whose `sample.stat` holds them.
        default_blocksize_parameter: Single parameter controlling the default block size
            for blockwise predictions (not every model supports this).
        default_batch_size: Batch size used by default.

    Raises:
        ImportError: if the client extra required for `server_type` is not installed.
    """
    backend = "gradio" if server_type is None else server_type
    server_type = backend

    try:
        if backend != "gradio":
            assert_never(backend)

        from .remote_backends.gradio.client import GradioPredictionPipeline

        pipeline_cls = GradioPredictionPipeline
    except ImportError as err:
        raise ImportError(
            f"Failed to import {server_type.capitalize()}PredictionPipeline. Make sure to install the '{server_type}-client' extra,"
            + f" e.g. with `pip install bioimageio.core[{server_type}-client]`."
        ) from err

    return pipeline_cls(
        model_description,
        server=server,
        precomputed_statistics=precomputed_statistics,
        default_blocksize_parameter=default_blocksize_parameter,
        default_batch_size=default_batch_size,
    )