Skip to content

Commit

Permalink
Minor change to allow wider use of gini()
Browse files Browse the repository at this point in the history
  • Loading branch information
Hu-JIN committed Feb 7, 2023
1 parent 199e68b commit 01b78e5
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions musical/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ def sort_with_indices(x):
return x[indices_sorted], indices_sorted


def gini(x):
def gini(x, assume_sorted=True):
"""
x is assumed to be sorted in increasing order (pass assume_sorted=False to have it sorted first)
"""
if not assume_sorted:
x = np.sort(x)
n = len(x)
aux = [xi * (2 * i - n + 1) for i, xi in enumerate(x)]
scaling = 1 / (n * sum(x))
Expand All @@ -39,7 +41,7 @@ def n_remove_gini(x, gini_delta, thresh):
gini_old = gini(x)
gini_new = gini(x[:-1])
n_remove = 0

while gini_old - gini_new > gini_delta and n_remove < max_n_remove:

n_remove += 1
Expand Down Expand Up @@ -68,7 +70,7 @@ def remove_samples_based_on_gini(H, X, gini_baseline=.65, gini_delta=.005):
as having unequal exposures
gini_delta: float
Per signature with unequal exposure, a sample is identified as a sample significantly contributing to
Per signature with unequal exposure, a sample is identified as a sample significantly contributing to
the high Gini coefficient if removing it decreases the Gini coefficient by at least 'gini_delta'
Output:
Expand All @@ -92,7 +94,7 @@ def remove_samples_based_on_gini(H, X, gini_baseline=.65, gini_delta=.005):
H = H /np.sum(X, axis=0)

# Gini coefficients of normalized signature exposures
gini_coeffs = np.array([gini(sorted(h)) for h in H])
gini_coeffs = np.array([gini(np.sort(h)) for h in H])
sigs_to_check = np.where(gini_coeffs > gini_baseline)[0]

samples_to_keep = {}
Expand All @@ -106,7 +108,7 @@ def remove_samples_based_on_gini(H, X, gini_baseline=.65, gini_delta=.005):
to_keep, to_remove = np.split(sorted_h_indices, [-n_remove]) if n_remove else (sorted_h_indices, np.empty(0))
samples_to_keep[sig_index] = np.array(sorted(to_keep))
samples_to_remove |= set(to_remove)

X_to_keep = {sig_index: X[:, samples] for sig_index, samples in samples_to_keep.items()}

samples_to_keep_all = set(range(n_samples)) - samples_to_remove
Expand Down Expand Up @@ -313,12 +315,12 @@ def stratify_samples(X, H=None, sil_thresh=0.9,
i.e., at most 1 cluster with silhouette score < sil_thresh, to accept the full clustering. Otherwise a smaller k might be
more appropriate. This is a bit more complicated. So we ignore this for now.
"""

if H is None:
data = normalize(X, norm='l1', axis=0)
else:
data = normalize(H, norm='l1', axis=0)

n_samples = data.shape[1]
# Clustering with automatic selection of cluster number
optimalK = OptimalK(data, max_k=max_k, nrefs=nrefs, metric=metric, linkage_method=linkage_method, ref_method=ref_method)
Expand All @@ -327,7 +329,7 @@ def stratify_samples(X, H=None, sil_thresh=0.9,
if k > 1:
# If k > 1, we check per-cluster silhouette scores.
# If at least one cluster has silhouette score > sil_thresh, we accept the clustering.
# Otherwise we reject it.
# Otherwise we reject it.
if np.any(optimalK.silscorek_percluster[k] > sil_thresh):
cluster_membership = optimalK.cluster_membership
clusters = []
Expand Down

0 comments on commit 01b78e5

Please sign in to comment.