Skip to content

Commit

Permalink
Merge pull request #64 from parklab/dev
Browse files Browse the repository at this point in the history
Bug fixes remove_samples_based_on_gini
  • Loading branch information
Hu-JIN authored Feb 7, 2023
2 parents d7038e1 + 01b78e5 commit c371845
Showing 1 changed file with 114 additions and 55 deletions.
169 changes: 114 additions & 55 deletions musical/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,58 +7,117 @@
from .cluster import OptimalK, hierarchical_cluster


def gini(x):
mad = np.abs(np.subtract.outer(x, x)).mean()
rmad = mad/np.mean(x)
g = 0.5 * rmad
return g

def remove_samples_based_on_gini(H, X, gini_baseline = 0.65, gini_delta = 0.005, per_signature = True):
gini_vec = []
gini_vec = np.array(gini_vec)
for h in H:
h_norm = h/np.sum(X, axis = 0)
gini_this = gini(h_norm)
gini_vec = np.append(gini_vec, gini_this)

inds_columns_to_check = np.where(gini_vec > gini_baseline)

indices_to_remove = []
indices_to_remove = np.array(inds_columns_to_check)

list_indices_to_keep = {}
# for i in np.array(inds_columns_to_check).tolist()[0]:
for i in inds_columns_to_check[0]:
h_norm = H[i,:]/np.sum(X, axis = 0)
index = h_norm.size
delta = 1
while delta > gini_delta:
gini_this = gini(np.sort(h_norm)[1:index])
gini_bef = gini(np.sort(h_norm)[1:(index - 1)])
delta = gini_this - gini_bef
index = index - 1
if index < np.around(h_norm.size * 0.8):
break
to_keep = np.where(h_norm < np.sort(h_norm)[index])
to_remove = np.where(h_norm >= np.sort(h_norm)[index])
list_indices_to_keep[i] = to_keep
indices_to_remove = np.append(indices_to_remove, to_remove)

indices_to_remove = np.unique(indices_to_remove)

if per_signature:
list_X = {}
indices_to_keep_all = np.array(range(0, H.shape[1]))
for i in inds_columns_to_check[0]:
X_this = X[:,list_indices_to_keep[i][0]]
list_X[i] = X_this
indices_this = list_indices_to_keep[i][0]
indices_to_keep_all = np.array([i for i in indices_to_keep_all if i in indices_this])
return(list_X, inds_columns_to_check, list_indices_to_keep, indices_to_keep_all)
else:
X_this = np.delete(X, indices_to_remove, axis = 1)
indices_to_keep = np.array([i for i in range(0, index) if i not in indices_to_remove])
return(X_this, indices_to_keep)
def sort_with_indices(x):
"""
x has to be a numpy array
"""
indices_sorted = np.argsort(x)

return x[indices_sorted], indices_sorted


def gini(x, assume_sorted=True):
"""
x is assumed to be increasing
"""
if not assume_sorted:
x = np.sort(x)
n = len(x)
aux = [xi * (2 * i - n + 1) for i, xi in enumerate(x)]
scaling = 1 / (n * sum(x))

return scaling * sum(aux)


def n_remove_gini(x, gini_delta, thresh):
"""
Identify how many of the largest values of x significantly contribute to its Gini coefficient.
Do not identify more than (1 - thresh) * 100 % of x.
x is assumed to be increasing.
"""
n_remove = 0
max_n_remove = np.round((1 - thresh) * len(x))

gini_old = gini(x)
gini_new = gini(x[:-1])
n_remove = 0

while gini_old - gini_new > gini_delta and n_remove < max_n_remove:

n_remove += 1
gini_old = gini_new
gini_new = gini(x[: - n_remove - 1])

return n_remove


def remove_samples_based_on_gini(H, X, gini_baseline=.65, gini_delta=.005):
"""
Identify signatures with unequal exposures. A signature is said to have unequal exposures if the
Gini coefficient of the sample exposures is higher than a given threshold.
For these signatures, the samples causing the gini coefficient to be high are also identified.
Input:
------
H: np.ndarray
The exposure matrix of shape (n_signatures, n_samples)
X: np.ndarray
The mutation count matrix of shape (n_features, n_samples)
gini_baseline: float
Signatures with exposures having a higher Gini coefficient than 'gini_baseline' are identified
as having unequal exposures
gini_delta: float
Per signature with unequal exposure, a sample is identified as a sample significanlty contributing
the high Gini coefficient if removing it decreases the Gini coefficient by at least 'gini_delta'
Output:
------
samples_to_keep: dict
keys: indices of signatures with unequal exposures
values: corresponding sample indices that do not (!) cause the Gini coefficient to be high
X_to_keep: dict
keys: indices of signatures with unequal exposures
values: mutation count matix subsetted to the samples that do not (!) cause the Gini coefficient to be high
samples_to_keep_all: np.ndarray
List of sample indices not significantly causing the Gini coefficient of any signature with unequal exposure to be high
"""
H, X = np.array(H), np.array(X)

n_samples = H.shape[1]

# normalize the exposures
H = H /np.sum(X, axis=0)

# Gini coefficients of normalized signature exposures
gini_coeffs = np.array([gini(np.sort(h)) for h in H])
sigs_to_check = np.where(gini_coeffs > gini_baseline)[0]

samples_to_keep = {}
samples_to_remove = set()

for sig_index in sigs_to_check:

sorted_h, sorted_h_indices = sort_with_indices(H[sig_index,:])
n_remove = n_remove_gini(sorted_h, gini_delta, .8)

to_keep, to_remove = np.split(sorted_h_indices, [-n_remove]) if n_remove else (sorted_h_indices, np.empty(0))
samples_to_keep[sig_index] = np.array(sorted(to_keep))
samples_to_remove |= set(to_remove)

X_to_keep = {sig_index: X[:, samples] for sig_index, samples in samples_to_keep.items()}

samples_to_keep_all = set(range(n_samples)) - samples_to_remove
samples_to_keep_all = np.array(sorted(samples_to_keep_all))

results = (samples_to_keep, X_to_keep, samples_to_keep_all)

return results


def identify_distinct_cluster(X, H, frac_thresh=0.05):
"""Identify distinct clusters from the cohort based on exposures.
Expand Down Expand Up @@ -256,12 +315,12 @@ def stratify_samples(X, H=None, sil_thresh=0.9,
i.e., at most 1 cluster with silhouette score < sil_thresh, to accept the full clustering. Otherwise a smaller k might be
more appropriate. This is a bit more complicated. So we ignore this for now.
"""

if H is None:
data = normalize(X, norm='l1', axis=0)
else:
data = normalize(H, norm='l1', axis=0)

n_samples = data.shape[1]
# Clustering with automatic selection of cluster number
optimalK = OptimalK(data, max_k=max_k, nrefs=nrefs, metric=metric, linkage_method=linkage_method, ref_method=ref_method)
Expand All @@ -270,7 +329,7 @@ def stratify_samples(X, H=None, sil_thresh=0.9,
if k > 1:
# If k > 1, we check per-cluster silhouette scores.
# If at least one cluster has silhouette score > sil_thresh, we accept the clustering.
# Otherwise we reject it.
# Otherwise we reject it.
if np.any(optimalK.silscorek_percluster[k] > sil_thresh):
cluster_membership = optimalK.cluster_membership
clusters = []
Expand Down

0 comments on commit c371845

Please sign in to comment.