Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mask growth breaks for edge cases #66

Open
MeyerBender opened this issue Apr 8, 2024 · 1 comment
Open

Mask growth breaks for edge cases #66

MeyerBender opened this issue Apr 8, 2024 · 1 comment

Comments

@MeyerBender
Copy link

Hi,

while investigating the mask growing method, I have come across some unexpected behavior, which looks incorrect to me. For example, notice how the mask on the left side of the image occupies pixels that overlap with other cells from the original segmentation.

Original image:
image

Image grown by 1px:
image

I have extracted the corresponding code snippets from the CVMask class to create this standalone example for testing:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from skimage.morphology import disk, dilation
from scipy.ndimage.morphology import binary_dilation
from sklearn.neighbors import kneighbors_graph
from scipy.spatial.distance import cdist

# adapted from CVMask
def compute_centroids(flatmasks):
    """Return the centroid of every labelled mask in a 2-D label image.

    Pixels equal to 0 are ignored.  The result is a list with one ``(x, y)``
    tuple per distinct non-zero label, in ascending label order, where ``x``
    is the mean first-axis index and ``y`` the mean second-axis index of the
    pixels carrying that label.
    """
    foreground = np.where(flatmasks != 0)
    row_idx, col_idx = foreground[0], foreground[1]
    mask_ids = flatmasks[row_idx, col_idx]

    # One record per foreground pixel: its two coordinates and its label.
    pixel_table = pd.DataFrame(
        np.column_stack((row_idx, col_idx, mask_ids)),
        columns=["x", "y", "id"],
    )
    per_mask_mean = pixel_table.groupby("id").agg({"x": "mean", "y": "mean"})
    return per_mask_mean.to_records(index=False).tolist()
    
# adapted from CVMask
def remove_overlaps_nearest_neighbors(centroids, masks):
    """Collapse a ``[row, col, layer]`` stack of grown masks into one label image.

    Every output pixel starts as the maximum value across layers.  Pixels
    where more than one layer is positive are then reassigned to whichever of
    the competing masks has the closest centroid (first minimum wins on a
    tie).  The centroid of mask id ``k`` is read from ``centroids[k - 1]``.
    """
    resolved = masks.max(axis=2)

    # Coordinates of pixels claimed by at least two layers.
    contested = np.nonzero((masks > 0).sum(axis=2) > 1)
    layer_values = masks[contested]

    # Long-format table: one row per (contested pixel, claiming mask id).
    claim_pos = np.nonzero(layer_values)
    claims = pd.DataFrame(
        np.column_stack((claim_pos[0], layer_values[claim_pos])),
        columns=["collis_idx", "mask_id"],
    )

    for pixel_no, claimants in claims.groupby("collis_idx"):
        pixel = np.array([[contested[0][pixel_no], contested[1][pixel_no]]])
        candidate_ids = claimants["mask_id"].tolist()
        candidate_centroids = np.array([centroids[mask_id - 1] for mask_id in candidate_ids])
        nearest = np.argmin(cdist(candidate_centroids, pixel))
        resolved[pixel[0, 0], pixel[0, 1]] = candidate_ids[nearest]

    return resolved

# adapted from CVMask
def grow_masks(flatmasks, centroids, growth, method = 'Standard', num_neighbors = 30):
    """Dilate every labelled mask by ``growth`` pixels and flatten the result.

    Masks are distributed over several image layers, each layer is dilated
    ``growth`` times with ``disk(1)``, and the layers are merged again by
    ``remove_overlaps_nearest_neighbors``.

    Parameters
    ----------
    flatmasks : 2-D label image; 0 is treated as background.  The code indexes
        masks as ``1 .. num_masks``, i.e. it relies on consecutive labels.
    centroids : only forwarded to ``remove_overlaps_nearest_neighbors``; the
        neighbour graph below uses centroids recomputed from ``flatmasks``.
    growth : number of dilation passes applied to each layer.
    method : only ``'Standard'`` is handled.
    num_neighbors : neighbour count passed to ``kneighbors_graph``.

    Returns
    -------
    The merged label image for ``method == 'Standard'``; for any other value
    the function falls off the end and returns ``None``.
    """
    masks = flatmasks
    # Number of labels, assuming the background value 0 occurs in the image.
    num_masks = len(np.unique(masks)) - 1

    # only looking at the standard method, but sequential also appears to have some issues
    if method == 'Standard':
        print("Standard growth selected")
        # These two assignments repeat the ones above.
        masks = flatmasks
        num_masks = len(np.unique(masks)) - 1
        indices = np.where(masks != 0)
        values = masks[indices[0], indices[1]]

        # Per-pixel table (coordinates + label) reduced to one centroid row per label.
        maskframe = pd.DataFrame(np.transpose(np.array([indices[0], indices[1], values]))).rename(columns = {0:"x", 1:"y", 2:"id"})
        cent_array = maskframe.groupby('id').agg({'x': 'mean', 'y': 'mean'}).to_numpy()
        # Scaling column j by j + 1 turns the non-zero entries of row n into
        # 1-based mask ids (presumably the neighbours of mask n + 1 - see the
        # kneighbors_graph documentation for the matrix layout).
        connectivity_matrix = kneighbors_graph(cent_array, num_neighbors).toarray() * np.arange(1, num_masks + 1)
        connectivity_matrix = connectivity_matrix.astype(int)
        # Greedy assignment of a layer index to every mask id.
        labels = {}
        for n in range(num_masks):
            connections = list(connectivity_matrix[n, :])
            # Drops only the first 0; any other zeros are filtered out by the
            # `i in labels` test below because every key of `labels` is >= 1.
            connections.remove(0)
            layers_used = [labels[i] for i in connections if i in labels]
            layers_used.sort()
            # Walk the sorted layers looking for the first gap.
            # NOTE(review): `layers_used` may contain duplicates.  For
            # [0, 0, 1] the loop breaks at the second 0 with currlayer == 1,
            # a layer that is already taken by one of the neighbours.
            currlayer = 0
            for layer in layers_used:
                if currlayer != layer: 
                    break
                currlayer += 1
            labels[n + 1] = currlayer

        possible_layers = len(list(set(labels.values())))
        label_frame = pd.DataFrame(list(labels.items()), columns = ["maskid", "layer"])
        image_h, image_w = masks.shape
        # One image plane per layer; each mask is copied into its own layer.
        expanded_masks = np.zeros((image_h, image_w, possible_layers), dtype = np.uint32)

        grouped_frame = label_frame.groupby('layer')
        for layer, group in grouped_frame:
            currids = list(group['maskid'])
            masklocs = np.isin(masks, currids)
            expanded_masks[masklocs, layer] = masks[masklocs]

        dilation_mask = disk(1)
        grown_masks = np.copy(expanded_masks)
        # All dilation passes run before overlaps are resolved once at the end.
        for _ in range(growth):
            for i in range(possible_layers):
                grown_masks[:, :, i] = dilation(grown_masks[:, :, i], dilation_mask)
        return remove_overlaps_nearest_neighbors(centroids, grown_masks)
        
# 20 x 20 label image used to reproduce the problem: 0 is background and the
# values 1-9 are the individual masks.
example_data = np.array([[6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 0, 0, 0, 0, 1, 1, 1, 1],
       [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 0, 0, 0, 1, 1, 1, 1, 1],
       [6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1, 1],
       [6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1],
       [2, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1],
       [2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1],
       [2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1],
       [2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1],
       [2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1],
       [2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1],
       [2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1],
       [2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 1, 1, 1, 1, 4],
       [2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 5, 5, 9, 9, 4, 4],
       [2, 2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 5, 5, 9, 9, 9, 9, 9],
       [2, 0, 0, 7, 7, 7, 7, 7, 7, 7, 5, 5, 5, 5, 9, 9, 9, 9, 9, 9],
       [0, 0, 0, 0, 7, 7, 7, 5, 5, 5, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9],
       [3, 0, 0, 0, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9],
       [3, 3, 0, 0, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 9, 9, 9, 9],
       [3, 3, 3, 0, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 9, 9, 9, 9],
       [3, 3, 8, 8, 8, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 9, 9, 9, 9]])

# Grow every mask by one pixel.  num_neighbors = 8 is one less than the nine
# labels present in example_data.
centroids = compute_centroids(example_data)
masks_grown = grow_masks(example_data, centroids, 1, method = 'Standard', num_neighbors = 8)
# Show the original segmentation, then the grown one, for visual comparison.
plt.imshow(example_data)
plt.show()
plt.imshow(masks_grown)
plt.show()

I would highly appreciate it if you could tell me whether I am using this method incorrectly, or whether this is actually a bug within the method. Thank you very much in advance!

@MeyerBender
Copy link
Author

MeyerBender commented Apr 9, 2024

There were two issues with the current methodology which I spotted:

  1. The remove_overlaps method only looked at the grown masks, but not the original ones. This led to unexpected side effects and sometimes even the removal of certain cells (if they were completely engulfed by a grown mask with a higher index).
  2. The remove_overlaps method should be called after each iteration to avoid pixels being dissociated from the parent masks.

Applying these changes to the example above, I get more sensible results.
image

Here is the code I used:

def compute_centroids(flatmasks):
    """Compute one ``(x, y)`` centroid per non-zero label of a 2-D label image.

    ``x`` is the mean index along the first axis and ``y`` the mean index
    along the second axis of all pixels carrying the label.  Zero-valued
    pixels are skipped.  Centroids are returned as a list of tuples sorted by
    label value.
    """
    hits = np.where(flatmasks != 0)
    first_axis = hits[0]
    second_axis = hits[1]

    # Columns: first-axis index, second-axis index, label of that pixel.
    coords_and_ids = np.column_stack((first_axis, second_axis, flatmasks[first_axis, second_axis]))
    frame = pd.DataFrame(coords_and_ids, columns=["x", "y", "id"])

    means = frame.groupby("id").agg({"x": "mean", "y": "mean"})
    return means.to_records(index=False).tolist()

def remove_overlaps_nearest_neighbors(original_masks, masks, centroids):
    """Merge grown mask layers into one label image without touching original pixels.

    ``masks`` is a ``[row, col, layer]`` stack of grown masks.  Pixels claimed
    by more than one layer go to the claimant whose centroid is nearest
    (``centroids[k - 1]`` is the centroid of mask id ``k``; the first minimum
    wins on a tie).  Afterwards every pixel that is non-zero in
    ``original_masks`` gets its original label back, so growth can only ever
    fill former background.  The result has the dtype of ``original_masks``.
    """
    merged = masks.max(axis=2)

    # Coordinates of pixels that are positive in two or more layers.
    contested = np.nonzero((masks > 0).sum(axis=2) > 1)
    contested_values = masks[contested]

    # Long-format table: one row per (contested pixel, claiming mask id).
    entry_pos = np.nonzero(contested_values)
    claims = pd.DataFrame(
        np.column_stack((entry_pos[0], contested_values[entry_pos])),
        columns=["collis_idx", "mask_id"],
    )

    for pixel_no, claimants in claims.groupby("collis_idx"):
        pixel = np.array([[contested[0][pixel_no], contested[1][pixel_no]]])
        candidate_ids = claimants["mask_id"].tolist()
        candidate_centroids = np.array([centroids[mask_id - 1] for mask_id in candidate_ids])
        nearest = np.argmin(cdist(candidate_centroids, pixel))
        merged[pixel[0, 0], pixel[0, 1]] = candidate_ids[nearest]

    # Keep the grown labels only where the original image was background,
    # then put the untouched original labels back on top.
    grown_only = (merged * (original_masks == 0)).astype(original_masks.dtype)
    return grown_only + original_masks

def _assign_layers(connectivity_matrix):
    """Give every mask the smallest layer not used by its already-placed neighbours.

    Row ``n`` of ``connectivity_matrix`` lists the ids of the neighbours of
    mask ``n + 1``; zero entries are padding.  Returns ``{mask_id: layer}``.
    """
    layer_of = {}
    for n, neighbour_ids in enumerate(connectivity_matrix):
        # A set rather than a sorted list: several neighbours commonly share a
        # layer, and walking a list with duplicates stops at the first repeat
        # and hands out a layer that a neighbour already occupies.
        taken = {layer_of[i] for i in neighbour_ids.tolist() if i in layer_of}
        layer = 0
        while layer in taken:
            layer += 1
        layer_of[n + 1] = layer
    return layer_of


def grow_masks(flatmasks, centroids, growth, num_neighbors=30):
    """Grow every labelled mask by ``growth`` pixels, one pixel per iteration.

    In each iteration the masks are spread over image layers so that
    neighbouring masks sit in different layers, every layer is dilated once
    with ``disk(1)``, and ``remove_overlaps_nearest_neighbors`` merges the
    layers again while keeping all previously assigned pixels.  Resolving
    overlaps after every single-pixel step keeps grown pixels attached to
    their parent mask.

    Parameters
    ----------
    flatmasks : 2-D label image; 0 is background and masks are expected to be
        labelled with the consecutive integers ``1 .. N``.
    centroids : sequence with the centroid of mask ``k`` at index ``k - 1``;
        forwarded to ``remove_overlaps_nearest_neighbors``.
    growth : number of one-pixel growth iterations.
    num_neighbors : neighbours considered per mask when assigning layers;
        capped at ``N - 1``.

    Returns
    -------
    The grown label image.  ``flatmasks`` itself is returned when ``growth``
    is 0 or when the image contains no masks.
    """
    masks = flatmasks
    # Count the non-zero labels directly so that an image without any
    # background pixel is not under-counted by one.
    num_masks = np.count_nonzero(np.unique(masks))
    if num_masks == 0:
        # Nothing to grow; the neighbour graph cannot be built on no points.
        return masks
    num_neighbors = min(num_neighbors, num_masks - 1)

    for _ in range(growth):
        if num_masks == 1:
            # A lone mask has no neighbours, and kneighbors_graph rejects a
            # neighbour count of 0, so place it in layer 0 directly.
            labels = {1: 0}
        else:
            # Centroids are recomputed from the current masks because the
            # masks change shape after every iteration.
            indices = np.where(masks != 0)
            values = masks[indices[0], indices[1]]
            maskframe = pd.DataFrame(
                np.transpose(np.array([indices[0], indices[1], values]))
            ).rename(columns={0: "x", 1: "y", 2: "id"})
            cent_array = maskframe.groupby("id").agg({"x": "mean", "y": "mean"}).to_numpy()
            # Scaling column j by j + 1 turns the adjacency entries of each
            # row into 1-based neighbour ids.
            connectivity_matrix = kneighbors_graph(cent_array, num_neighbors).toarray() * np.arange(1, num_masks + 1)
            labels = _assign_layers(connectivity_matrix.astype(int))

        possible_layers = len(set(labels.values()))
        label_frame = pd.DataFrame(list(labels.items()), columns=["maskid", "layer"])
        image_h, image_w = masks.shape
        # One image plane per layer; every mask is copied into its own layer.
        expanded_masks = np.zeros((image_h, image_w, possible_layers), dtype=np.uint32)
        for layer, group in label_frame.groupby("layer"):
            masklocs = np.isin(masks, list(group["maskid"]))
            expanded_masks[masklocs, layer] = masks[masklocs]

        dilation_mask = disk(1)
        grown_masks = np.copy(expanded_masks)
        for i in range(possible_layers):
            grown_masks[:, :, i] = dilation(grown_masks[:, :, i], dilation_mask)
        masks = remove_overlaps_nearest_neighbors(masks, grown_masks, centroids)

    return masks

In my tests, this altered version performed as I expected. Of course, you should test it on some of your own examples, but I believe that these changes fix the mask growing (at least for the Standard method), and you might want to consider incorporating them into the CellSeg codebase.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant