feat: Support the Cellpose2 workflow (#1)
timwiggin authored Feb 8, 2024
1 parent 06df04a commit 4a2d030
Showing 9 changed files with 1,697 additions and 1,311 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/main.yml
@@ -33,7 +33,7 @@ jobs:
      - name: Install Poetry
        uses: snok/install-poetry@v1
        with:
-          version: 1.5.1
+          version: 1.7.0
          virtualenvs-create: true
          virtualenvs-in-project: true

@@ -80,7 +80,7 @@ jobs:
        if: steps.cached-poetry.outputs.cache-hit != 'true'
        uses: snok/install-poetry@v1
        with:
-          version: 1.5.1
+          version: 1.7.0
          virtualenvs-create: true
          virtualenvs-in-project: true

@@ -137,7 +137,7 @@ jobs:
      - name: Install Poetry
        uses: snok/install-poetry@v1
        with:
-          version: 1.5.1
+          version: 1.7.0
          virtualenvs-create: true
          virtualenvs-in-project: true

2,860 changes: 1,597 additions & 1,263 deletions poetry.lock

Large diffs are not rendered by default.

28 changes: 14 additions & 14 deletions pyproject.toml
@@ -24,21 +24,21 @@ packages = [{ include = "vpt_core*", from = "src" }]

[tool.poetry.dependencies]
python = ">=3.9,<3.11"
-boto3 = "1.17"
-fsspec = "2021.10.0"
-geopandas = "0.12.1"
-gcsfs = "2021.10.0"
-numpy = "1.22.4"
-opencv-python-headless = "4.6.0.66"
-pandas = "1.4.3"
+boto3 = ">=1.17"
+fsspec = ">=2021.10.0"
+geopandas = ">=0.13.2"
+gcsfs = ">=2021.10.0"
+numpy = "^1.24.3"
+opencv-python-headless = ">=4.6.0.66"
+pandas = "^2.0.3"
psutil = "*"
-pyarrow = "8.0.0"
-python-dotenv = "0.20.0"
-rasterio = "1.3.0"
-s3fs = "2021.10.0"
-scikit-image = "0.19.3"
-shapely = "2.0.0"
-tenacity = "8.2.2"
+pyarrow = ">=8.0.0, <14.0.0"
+python-dotenv = ">=0.20.0"
+rasterio = ">=1.3.0, <1.3.6"
+s3fs = ">=2021.10.0"
+scikit-image = ">=0.19.3"
+shapely = ">=2.0.0"
+tenacity = ">=8.2.2"
tqdm = "*"
tqdm = "*"

[tool.poetry.group.dev.dependencies]
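Note on the dependency changes above: exact pins become ranges, with geopandas raised to >=0.13.2 (the release line where sindex.query_bulk was folded into sindex.query) and pandas/numpy moved to caret ranges. A minimal, stdlib-only sketch for checking which versions actually resolved in the project environment:

# Sanity-check sketch (assumes it runs inside the project's virtualenv):
# print the resolved versions of the loosened dependencies.
from importlib.metadata import version

for pkg in ("geopandas", "numpy", "pandas", "pyarrow", "rasterio", "shapely"):
    print(pkg, version(pkg))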
4 changes: 3 additions & 1 deletion src/vpt_core/io/input_tools.py
@@ -1,12 +1,14 @@
import geopandas as gpd
-from geopandas.geodataframe import DEFAULT_GEO_COLUMN_NAME
from pyarrow import parquet, Table
from shapely import wkb

from vpt_core.io.vzgfs import vzg_open, retrying_attempts
from vpt_core.segmentation.seg_result import SegmentationResult


+DEFAULT_GEO_COLUMN_NAME = "geometry"


def pyarrow_table_to_pandas(table: Table):
    df = gpd.GeoDataFrame(table.to_pandas(integer_object_nulls=True))
    if SegmentationResult.geometry_field in df.columns:
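With the geopandas-internal import removed, the module now defines its own DEFAULT_GEO_COLUMN_NAME constant instead of relying on geopandas.geodataframe. A toy sketch of the conversion step pyarrow_table_to_pandas performs (the column data here is illustrative):

# Toy illustration: build a pyarrow Table and wrap it as a GeoDataFrame,
# keeping nullable integers as Python objects.
import geopandas as gpd
from pyarrow import Table

table = Table.from_pydict({"EntityID": [1, 2], "ZIndex": [0, 0]})
df = gpd.GeoDataFrame(table.to_pandas(integer_object_nulls=True))
print(df.dtypes)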
15 changes: 10 additions & 5 deletions src/vpt_core/io/vzgfs.py
@@ -12,7 +12,12 @@
from gcsfs import GCSFileSystem
from rasterio.errors import NotGeoreferencedWarning
from s3fs import S3FileSystem
-from tenacity import stop_after_delay, stop_after_attempt, retry_if_exception_type, wait_fixed
+from tenacity import (
+    stop_after_delay,
+    stop_after_attempt,
+    retry_if_exception_type,
+    wait_fixed,
+)

from vpt_core import (
    AWS_ACCESS_KEY_VAR,
@@ -126,18 +131,18 @@ def retrying_attempts():
    )


-def vzg_open(uri: str, mode: str):
+def vzg_open(uri: str, mode: str, **kwargs):
    protocol, path = protocol_path_split(uri)
    fs = filesystem_for_protocol(protocol)

    assert fs is not None

-    return fs.open(path, mode)
+    return fs.open(path, mode, **kwargs)


-def io_with_retries(uri: str, mode: str, callback: Callable):
+def io_with_retries(uri: str, mode: str, callback: Callable, **kwargs):
    for attempt in retrying_attempts():
-        with attempt, vzg_open(uri, mode) as f:
+        with attempt, vzg_open(uri, mode, **kwargs) as f:
            return callback(f)


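The new **kwargs on vzg_open and io_with_retries are forwarded straight to fsspec's fs.open, so callers can set open-time options without touching the retry loop. A hypothetical usage sketch (the path and the encoding argument are illustrative, not from this commit):

# Hypothetical call: read a JSON document with an explicit text encoding,
# retried according to retrying_attempts(); kwargs flow through to fs.open().
import json

from vpt_core.io.vzgfs import io_with_retries

data = io_with_retries("analysis/manifest.json", "r", json.load, encoding="utf-8")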
5 changes: 5 additions & 0 deletions src/vpt_core/segmentation/__init__.py
@@ -1,3 +1,8 @@
+import warnings
+
+warnings.filterwarnings("ignore", category=FutureWarning)
+
+
class ResultFields:
    detection_id_field: str = "ID"
    cell_id_field: str = "EntityID"
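The filter above is process-wide: importing vpt_core.segmentation silences every FutureWarning from then on. For comparison only (a sketch, not what the commit does), a scoped variant would look like:

import warnings

# Scoped alternative sketch: suppress FutureWarning only around
# specific calls instead of globally at import time.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=FutureWarning)
    ...  # pandas/geopandas calls that would otherwise warn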
68 changes: 44 additions & 24 deletions src/vpt_core/segmentation/seg_result.py
@@ -167,53 +167,60 @@ def _assign_entity_index(
        self, entities_df: GeoDataFrame, entities_ids: List, cur_poly: MultiPolygon, coverage_threshold: float
    ) -> List[int]:
        res = []
-        geometry_col = self.df.columns.get_loc(self.geometry_field)
+        id_values = entities_df[self.cell_id_field].values
+        geoms = entities_df.geometry.values
        for i in entities_ids:
-            prev_poly = entities_df[entities_df[self.cell_id_field] == i].iloc[0, geometry_col]
+            prev_poly = geoms[id_values == i][0]
            polys_intersection = prev_poly.intersection(cur_poly)
            if (polys_intersection.area / min(prev_poly.area, cur_poly.area)) > coverage_threshold:
                res.append(i)
        return res

-    def _update_entities_index(self, overlaps_pairs: np.ndarray, cur_cells_ids: List, updated_ids: List, z_index: int):
-        cur_df = self.df.loc[self.df[self.z_index_field] == z_index]
+    def _update_entities_index(
+        self, overlaps_pairs: np.ndarray, cur_cells_ids: np.ndarray, cur_df: GeoDataFrame, updated_ids: List
+    ):
+        id_values = cur_df[self.cell_id_field].values
+        index_values = cur_df.index
        for old_id, new_id in updated_ids:
-            for j in cur_df.loc[cur_df[self.cell_id_field] == old_id].index:
+            for j in index_values[np.where(id_values == old_id)]:
                self.df.at[j, self.cell_id_field] = new_id
+                cur_df.at[j, self.cell_id_field] = new_id
-            for i in [index for index in range(len(cur_cells_ids)) if cur_cells_ids[index] == old_id]:
-                cur_cells_ids[i] = new_id
+            cur_cells_ids[cur_cells_ids == old_id] = new_id

            if len(overlaps_pairs) == 0:
                return
            overlaps_pairs[overlaps_pairs[:, 0] == old_id, 0] = new_id
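Both rewrites above swap repeated DataFrame filtering for plain numpy array operations. A toy sketch of the two idioms with stand-in data:

# Toy sketch: boolean-mask lookup of one entity's geometry, and
# in-place id replacement on an ndarray (stand-ins, not real polygons).
import numpy as np

id_values = np.array([7, 8, 8, 9])
geoms = np.array(["g7", "g8a", "g8b", "g9"], dtype=object)
prev_poly = geoms[id_values == 8][0]    # first geometry with id 8 -> "g8a"

cur_cells_ids = np.array([7, 8, 9])
cur_cells_ids[cur_cells_ids == 8] = 42  # replaces the old Python-level loop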

    def fuse_across_z(self, coverage_threshold: float = 0.5):
        if len(self.df) == 0:
            return
        z_lines = self.df["ZIndex"].unique()
        reserved_ids = []
-        geometry_col = self.df.columns.get_loc(self.geometry_field)
+        prev_df = self.df.loc[self.df[self.z_index_field] == z_lines[0]]
        for z_i in range(1, len(z_lines)):
-            prev_df = self.df.loc[self.df[self.z_index_field] == z_lines[z_i - 1]]
+            prev_df = prev_df.assign(**{self.z_index_field: z_lines[z_i]})
            cur_df = self.df.loc[self.df[self.z_index_field] == z_lines[z_i]]

            reserved_ids.extend(prev_df[self.cell_id_field].to_list())
-            cur_cells_ids = cur_df[self.cell_id_field].unique()
+            cell_ids_values = cur_df[self.cell_id_field].values
+            cur_cells_ids = np.unique(cell_ids_values)
+            cur_geoms = cur_df.geometry.values
            overlaps = np.array(self.find_overlapping_entities(prev_df, cur_df))

            for i in range(len(cur_cells_ids)):
                cur_i = cur_cells_ids[i]
                cur_intersected_i_to_update = []
                need_new_index = True
                if len(overlaps) > 0 and cur_i in overlaps[:, 0]:
-                    cur_poly = cur_df[cur_df[self.cell_id_field] == cur_i].iloc[0, geometry_col]
+                    cur_poly = cur_geoms[cell_ids_values == cur_i][0]
                    prev_ids = [pair[1] for pair in overlaps if pair[0] == cur_i]

                    prev_ids = self._assign_entity_index(prev_df, prev_ids, cur_poly, coverage_threshold)
                    if len(prev_ids) > 1:
                        prev_i_to_update = [(prev_i, prev_ids[0]) for prev_i in prev_ids[1:]]
                        for z in z_lines[:z_i]:
-                            self._update_entities_index(np.array([]), [], prev_i_to_update, z)
+                            df_to_update = self.df.loc[self.df[self.z_index_field] == z]
+                            self._update_entities_index(np.array([]), np.array([]), df_to_update, prev_i_to_update)

                    if len(prev_ids) > 0:
                        prev_i = prev_ids[0]
@@ -228,19 +235,30 @@ def fuse_across_z(self, coverage_threshold: float = 0.5):
                    if cur_i in reserved_ids:
                        new_id = max(max(reserved_ids), max(cur_cells_ids)) + 1
                        cur_intersected_i_to_update.append((cur_i, new_id))
-                self._update_entities_index(overlaps, cur_cells_ids, cur_intersected_i_to_update, z_lines[z_i])
-                cur_df = self.df.loc[self.df[self.z_index_field] == z_lines[z_i]]
+                self._update_entities_index(overlaps, cur_cells_ids, cur_df, cur_intersected_i_to_update)

+            prev_df = cur_df

        self.group_duplicated_entities()

    def group_duplicated_entities(self):
-        for z, df in self.df.groupby(self.z_index_field):
-            for entity_id, entity_df in df.groupby(self.cell_id_field):
-                if len(entity_df) == 1:
-                    continue
-                poly = convert_to_multipoly(get_valid_geometry(unary_union(entity_df[self.geometry_field])))
-                self.df.at[entity_df.index[0], self.geometry_field] = poly
-                self.df.drop(entity_df.index[1:], inplace=True)
+        duplicated_fields = [self.z_index_field, self.cell_id_field]
+        duplicated = self.df.duplicated(duplicated_fields, keep=False)
+        if len(duplicated) == 0:
+            return
+
+        grouped = (
+            self.df[duplicated]
+            .groupby(duplicated_fields)[self.geometry_field]
+            .apply(lambda geoms: convert_to_multipoly(get_valid_geometry(unary_union(geoms))))
+        )
+
+        self.df.drop_duplicates(duplicated_fields, inplace=True)
+        merged_geoms = self.df[duplicated].apply(
+            lambda row: grouped[row[self.z_index_field], row[self.cell_id_field]], axis=1
+        )
+        if len(merged_geoms) > 0:
+            self.df[self.geometry_field].update(merged_geoms)
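The rewritten group_duplicated_entities replaces the nested per-z, per-entity loops with a single duplicated-row mask: dissolve each duplicated (ZIndex, EntityID) group's geometry, drop the extra rows, and write the merged geometry back onto the kept row. A toy sketch of the core pattern:

# Toy sketch of the duplicated-mask + groupby + unary_union pattern.
import geopandas as gpd
from shapely.geometry import box
from shapely.ops import unary_union

df = gpd.GeoDataFrame(
    {"ZIndex": [0, 0, 0], "EntityID": [1, 1, 2]},
    geometry=[box(0, 0, 2, 2), box(1, 1, 3, 3), box(5, 5, 6, 6)],
)
dup = df.duplicated(["ZIndex", "EntityID"], keep=False)
merged = df[dup].groupby(["ZIndex", "EntityID"])["geometry"].apply(unary_union)
print(merged)  # one dissolved polygon for (ZIndex=0, EntityID=1)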

    @staticmethod
    def _separate_multi_geometries(data: GeoDataFrame) -> GeoDataFrame:
@@ -314,7 +332,7 @@ def find_overlapping_entities(dataframe_1: GeoDataFrame, dataframe_2: Optional[G

        gdf2 = dataframe_2 if dataframe_2 is not None else dataframe_1

-        overlaps = dataframe_1.sindex.query_bulk(gdf2.geometry, predicate="intersects").T
+        overlaps = dataframe_1.sindex.query(gdf2.geometry, predicate="intersects").T

        # Remove self-intersections
        if dataframe_2 is None:
@@ -350,7 +368,7 @@ def get_z_geoms(self, z_line: int) -> GeoSeries:
        return self.df.loc[self.df[self.z_index_field] == z_line, self.geometry_field]

    @staticmethod
-    def combine_segmentations(segmentations: List):
+    def combine_segmentations(segmentations: List[SegmentationResult]) -> SegmentationResult:
        non_empty_segmentations = [seg for seg in segmentations if len(seg.df) > 0]
        if len(non_empty_segmentations) > 1:
            to_concat = [seg.df for seg in non_empty_segmentations]
@@ -369,8 +387,10 @@ def combine_segmentations(segmentations: List):
            res = SegmentationResult(dataframe=df)
            res.set_column(SegmentationResult.detection_id_field, list(range(len(res.df))))
            return res
+        elif len(non_empty_segmentations) == 1:
+            return non_empty_segmentations[0]
        else:
-            return non_empty_segmentations[0] if len(non_empty_segmentations) > 0 else SegmentationResult()
+            return segmentations[0] if len(segmentations) > 0 else SegmentationResult()

    def _union_entities(self, base_gdf, add_gdf):
        """
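In find_overlapping_entities above, sindex.query_bulk becomes sindex.query, matching the new geopandas >=0.13.2 floor, where query accepts an array of geometries and query_bulk is deprecated. A toy sketch of the call and its transposed (n_pairs, 2) result:

# Toy sketch: array-input query against a spatial index; after .T each
# row of `pairs` is (index into b, index into a).
import geopandas as gpd
from shapely.geometry import box

a = gpd.GeoDataFrame(geometry=[box(0, 0, 2, 2), box(10, 10, 11, 11)])
b = gpd.GeoDataFrame(geometry=[box(1, 1, 3, 3)])
pairs = a.sindex.query(b.geometry, predicate="intersects").T
print(pairs)  # [[0 0]]: b's square intersects a's first square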
20 changes: 20 additions & 0 deletions tests/vpt_core/test_seg_result.py
@@ -138,3 +138,23 @@ def test_3d_splitted_difference() -> None:
    )

    assert_seg_equals(s1, gt_seg, 1)


+def test_combine():
+    empty_df = SegmentationResult.combine_segmentations([])
+    assert all(empty_df.df.columns == SegmentationResult().df.columns)
+    assert len(empty_df.df) == 0
+
+    empty_df_custom_columns = SegmentationResult()
+    empty_df_custom_columns.set_column("custom", None)
+    result = SegmentationResult.combine_segmentations([empty_df_custom_columns])
+    assert all(result.df.columns == empty_df_custom_columns.df.columns)
+    assert len(result.df) == 0
+
+    seg_results = [
+        from_shapes_3d([[Square(30, 30, 30), Square(50, 50, 30)]]),
+        from_shapes_3d([[Square(100, 50, 30)]], cids=[10]),
+    ]
+    result1 = SegmentationResult.combine_segmentations([*seg_results, SegmentationResult()])
+    assert len(result1.df) == 3
+    assert all(seg_results[0].df.columns == result1.df.columns)
2 changes: 1 addition & 1 deletion tests/vpt_core/test_vzgfs.py
@@ -12,7 +12,7 @@ def random_throw_timeout(*args, **kwargs):
    if "calls" not in random_throw_timeout.__dict__:
        random_throw_timeout.calls = 0
    random_throw_timeout.calls += 1
-    probability = [100, 50, 25, 10, 0]
+    probability = [100, 50, 25, 0]
    if random_throw_timeout.calls + 1 > len(probability):
        return
    if random.uniform(1, 100) < probability[random_throw_timeout.calls - 1]:
