Coverage for src / bioimageio / core / _ops_cellpose.py: 86%

57 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-18 12:35 +0000

1from dataclasses import dataclass 

2from typing import Any, Collection 

3 

4import numpy as np 

5from numpy.typing import NDArray 

6from typing_extensions import Literal, cast 

7 

8from bioimageio.spec.model.v0_5 import CellposeFlowDynamicsDescr 

9 

10from ._op_base import SamplewiseOperator 

11from .axis import AxisId, PerAxis 

12from .common import MemberId 

13from .sample import Sample 

14from .stat_measures import Measure 

15from .tensor import Tensor 

16 

17 

@dataclass
class CellposeFlowDynamics(SamplewiseOperator):
    """Cellpose flow-dynamics postprocessing operator.

    Replaces the `labels_id` member of a sample with instance labels
    (0 = background, dtype given by `output_dtype`) decoded from the flow
    fields and cell probability output of a Cellpose model.

    The input member must have axes ``(batch, channel, *spatial)``. The channel
    axis stacks the flow components followed by the cell probability:
    ``flow_y, flow_x, cellprob`` (2D) or ``flow_z, flow_y, flow_x, cellprob``
    (when `do_3D` is set). The output keeps the same axes, with a singleton
    channel axis.
    """

    # forwarded to cellpose's `compute_masks(cellprob_threshold=...)`
    cellprob_threshold: float = 0.0
    # forwarded to cellpose's `compute_masks(flow_threshold=...)`
    flow_threshold: float = 0.4
    # forwarded to `compute_masks(do_3D=...)`; also selects 3 instead of 2 flow channels
    do_3D: bool = False
    min_size: int = 15
    """Minimum size of objects to keep, in pixels. Default is 15, which is the default in Cellpose. Set to 0 to disable filtering by size."""
    # member that is read (flows + cellprob) and overwritten with the labels
    labels_id: MemberId = MemberId("labels")
    # dtype of the label image; must be able to hold the largest label value
    output_dtype: Literal["uint16", "uint32"] = "uint16"

    @classmethod
    def from_proc_descr(
        cls, proc_descr: CellposeFlowDynamicsDescr, member_id: MemberId
    ) -> "CellposeFlowDynamics":
        """Create the operator from a bioimageio processing description.

        `member_id` becomes `labels_id` (the member that is read and
        overwritten); all other settings are taken from `proc_descr.kwargs`.
        """
        kwargs = proc_descr.kwargs
        return cls(
            labels_id=member_id,
            cellprob_threshold=kwargs.cellprob_threshold,
            flow_threshold=kwargs.flow_threshold,
            do_3D=kwargs.do_3D,
            output_dtype=kwargs.output_dtype,
            min_size=kwargs.min_size,
        )

    @property
    def required_measures(self) -> Collection[Measure]:
        """This operator needs no precomputed statistics."""
        return set()

    def get_output_shape(self, input_shape: PerAxis[int]) -> PerAxis[int]:
        """Return `input_shape` with the channel axis set to size 1."""
        output_shape = dict(input_shape)
        output_shape[AxisId("channel")] = 1
        return output_shape

    def __call__(self, sample: Sample) -> None:
        """Decode the `labels_id` member of `sample` and replace it in place."""
        input_tensor = sample.members[self.labels_id]
        output_tensor = self._apply(input_tensor)
        sample.members[self.labels_id] = output_tensor

    def _apply(self, x: Tensor) -> Tensor:
        """Validate the axes of `x` and decode every batch entry.

        Raises:
            ValueError: if `x` does not have axes ``(batch, channel, ...)`` or
                the channel axis does not hold the number of channels required
                by `do_3D` (2 or 3 flow components plus cellprob).
        """
        dims = x.dims
        if len(dims) < 1 or dims[0] != AxisId("batch"):
            raise ValueError(
                "Expected first axis to be 'batch' for cellpose flow dynamics."
            )

        # flow components (y, x in 2D; z, y, x in 3D) followed by one cellprob channel
        flow_names = "flow_z, flow_y, flow_x" if self.do_3D else "flow_y, flow_x"
        n_channels = (3 if self.do_3D else 2) + 1

        if len(dims) < 2 or dims[1] != AxisId("channel"):
            raise ValueError(
                f"Expected second axis to be 'channel' with {n_channels} channels"
                + " for cellpose flow dynamics."
            )
        if x.shape[1] != n_channels:
            raise ValueError(
                f"Expected {n_channels} stacked tensors along second 'channel' axis:"
                + f" {flow_names}, cellprob for cellpose flow dynamics."
            )

        masks = [self._apply_impl(xx) for xx in x]
        return Tensor.from_numpy(np.stack(masks, axis=0), dims=dims)

    def _apply_impl(self, x: Tensor) -> NDArray[Any]:
        """Apply on a tensor without batch dimension, i.e. ``(channel, *spatial)``.

        Returns:
            Label image of shape ``(1, *spatial)`` with dtype `output_dtype`.

        Raises:
            ImportError: if cellpose is not installed.
            ValueError: if the largest label value does not fit into `output_dtype`.
        """
        try:
            from cellpose.dynamics import (  # pyright: ignore[reportMissingTypeStubs]
                compute_masks,  # pyright: ignore[reportUnknownVariableType]
            )
            from cellpose.utils import (  # pyright: ignore[reportMissingTypeStubs]
                fill_holes_and_remove_small_masks,  # pyright: ignore[reportUnknownVariableType]
            )
        except ImportError as e:
            raise ImportError(
                "cellpose is required for cellpose_flow_dynamics. Install with: pip install cellpose"
            ) from e

        # unpack along the channel axis: leading channels are flows, the last is cellprob
        *flow_channels, cellprob = x.to_numpy()
        flow_field = np.stack(flow_channels, axis=0)
        mask = cast(
            NDArray[Any],
            compute_masks(
                flow_field,
                cellprob,
                cellprob_threshold=self.cellprob_threshold,
                flow_threshold=self.flow_threshold,
                do_3D=self.do_3D,
            ),
        )
        # NOTE(review): depending on the cellpose version, `compute_masks` may already
        # drop small masks using its own default min_size; confirm that
        # min_size=0 fully disables size filtering.
        mask = np.asarray(
            fill_holes_and_remove_small_masks(mask, min_size=self.min_size)
        )

        # refuse to silently wrap label values that do not fit the requested dtype,
        # which would merge distinct instances
        dtype = np.dtype(self.output_dtype)
        max_label = int(mask.max()) if mask.size else 0
        max_allowed = int(np.iinfo(dtype).max)
        if max_label > max_allowed:
            raise ValueError(
                f"Largest instance label {max_label} does not fit into output_dtype"
                + f" '{self.output_dtype}' (max {max_allowed});"
                + " choose a wider output_dtype."
            )

        # add singleton channel axis for output to keep dims consistent with postprocessing input
        return mask[None].astype(dtype)