Coverage for src / bioimageio / core / _ops_cellpose.py: 86%
57 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
1from dataclasses import dataclass
2from typing import Any, Collection
4import numpy as np
5from numpy.typing import NDArray
6from typing_extensions import Literal, cast
8from bioimageio.spec.model.v0_5 import CellposeFlowDynamicsDescr
10from ._op_base import SamplewiseOperator
11from .axis import AxisId, PerAxis
12from .common import MemberId
13from .sample import Sample
14from .stat_measures import Measure
15from .tensor import Tensor
@dataclass
class CellposeFlowDynamics(SamplewiseOperator):
    """Cellpose flow-dynamics postprocessing operator.

    Reads the sample member `labels_id`, which must hold the stacked output of a
    Cellpose model with dims (batch, channel, ...) and exactly 3 channels
    (flow_y, flow_x, cellprob). Decodes it into instance labels (0 = background)
    and writes the labels back to the *same* member. The result keeps a singleton
    channel axis and has dtype `output_dtype`.

    The optional `cellpose` package is imported lazily when the operator runs.
    """

    # Forwarded unchanged to `cellpose.dynamics.compute_masks`.
    cellprob_threshold: float = 0.0
    # Forwarded unchanged to `cellpose.dynamics.compute_masks`.
    flow_threshold: float = 0.4
    # Forwarded unchanged to `cellpose.dynamics.compute_masks`.
    do_3D: bool = False
    min_size: int = 15
    """Minimum size of objects to keep, in pixels. Default is 15, which is the default in Cellpose. Set to 0 to disable filtering by size."""
    # Sample member that is read as model output and overwritten with the labels.
    labels_id: MemberId = MemberId("labels")
    # Integer dtype of the produced label image.
    output_dtype: Literal["uint16", "uint32"] = "uint16"

    @classmethod
    def from_proc_descr(
        cls, proc_descr: CellposeFlowDynamicsDescr, member_id: MemberId
    ) -> "CellposeFlowDynamics":
        """Build the operator from its spec description.

        `member_id` becomes `labels_id`. All other settings are copied from
        `proc_descr.kwargs`.
        """
        kwargs = proc_descr.kwargs
        return cls(
            labels_id=member_id,
            cellprob_threshold=kwargs.cellprob_threshold,
            flow_threshold=kwargs.flow_threshold,
            do_3D=kwargs.do_3D,
            output_dtype=kwargs.output_dtype,
            min_size=kwargs.min_size,
        )

    @property
    def required_measures(self) -> Collection[Measure]:
        """No dataset statistics are needed by this operator."""
        return set()

    def get_output_shape(self, input_shape: PerAxis[int]) -> PerAxis[int]:
        """Return `input_shape` with the 'channel' axis reduced to size 1."""
        output_shape = dict(input_shape)
        output_shape[AxisId("channel")] = 1
        return output_shape

    def __call__(self, sample: Sample) -> None:
        """Replace `sample.members[self.labels_id]` with decoded instance labels, in place."""
        input_tensor = sample.members[self.labels_id]
        output_tensor = self._apply(input_tensor)
        sample.members[self.labels_id] = output_tensor

    def _apply(self, x: Tensor) -> Tensor:
        """Decode every batch entry of `x` into a label image.

        Raises:
            ValueError: if `x` does not have dims (batch, channel, ...) with exactly
                3 channels, or if a label id does not fit into `output_dtype`.
        """
        dims = tuple(x.dims)
        # Check the rank before indexing, so malformed input yields ValueError,
        # not IndexError.
        if not dims or dims[0] != AxisId("batch"):
            raise ValueError(
                "Expected first axis to be 'batch' for cellpose flow dynamics."
            )

        if len(dims) < 2 or dims[1] != AxisId("channel"):
            raise ValueError(
                "Expected second axis to be 'channel' with 3 channels for cellpose flow dynamics."
            )
        if x.shape[1] != 3:
            raise ValueError(
                "Expected 3 stacked tensors along first 'channel' axis: flow_y, flow_x, cellprob for cellpose flow dynamics."
            )

        masks = [self._apply_impl(xx) for xx in x]
        if masks:
            stacked = np.stack(masks, axis=0)
        else:
            # np.stack rejects an empty list. Produce an empty batch with a
            # singleton channel axis, matching the non-empty output layout.
            stacked = np.zeros(
                (0, 1, *x.shape[2:]), dtype=np.dtype(self.output_dtype)
            )
        return Tensor.from_numpy(stacked, dims=x.dims)

    def _apply_impl(self, x: Tensor) -> NDArray[Any]:
        """Decode a single tensor (without batch axis) into a label image.

        `x` is unpacked along its first (channel) axis. All leading channels are
        stacked as the flow field, and the last channel is used as cell probability.

        Returns:
            The label image with a leading singleton channel axis and dtype
            `output_dtype`.

        Raises:
            ImportError: if `cellpose` is not installed.
            ValueError: if the largest label id does not fit into `output_dtype`.
        """
        *flows, cellprob = x.to_numpy()
        try:
            from cellpose.dynamics import (  # pyright: ignore[reportMissingTypeStubs]
                compute_masks,  # pyright: ignore[reportUnknownVariableType]
            )
            from cellpose.utils import (  # pyright: ignore[reportMissingTypeStubs]
                fill_holes_and_remove_small_masks,  # pyright: ignore[reportUnknownVariableType]
            )
        except ImportError as e:
            raise ImportError(
                "cellpose is required for cellpose_flow_dynamics. Install with: pip install cellpose"
            ) from e

        flows = np.stack(flows, axis=0)
        mask = cast(
            NDArray[Any],
            compute_masks(
                flows,
                cellprob,
                cellprob_threshold=self.cellprob_threshold,
                flow_threshold=self.flow_threshold,
                do_3D=self.do_3D,
            ),
        )
        mask = np.asarray(
            fill_holes_and_remove_small_masks(mask, min_size=self.min_size)
        )

        out_dtype = np.dtype(self.output_dtype)
        # `astype` silently wraps values outside the target range. That would
        # merge distinct instances, so fail loudly instead.
        max_label = int(mask.max()) if mask.size else 0
        if max_label > np.iinfo(out_dtype).max:
            raise ValueError(
                f"Largest label id {max_label} does not fit into output_dtype"
                + f" '{self.output_dtype}' (max {np.iinfo(out_dtype).max})."
            )

        # add singleton channel axis for output to keep dims consistent with postprocessing input
        return mask[None].astype(out_dtype)