diff --git a/musical/preprocessing.py b/musical/preprocessing.py index 3738170..74000f9 100644 --- a/musical/preprocessing.py +++ b/musical/preprocessing.py @@ -7,58 +7,117 @@ from .cluster import OptimalK, hierarchical_cluster -def gini(x): - mad = np.abs(np.subtract.outer(x, x)).mean() - rmad = mad/np.mean(x) - g = 0.5 * rmad - return g - -def remove_samples_based_on_gini(H, X, gini_baseline = 0.65, gini_delta = 0.005, per_signature = True): - gini_vec = [] - gini_vec = np.array(gini_vec) - for h in H: - h_norm = h/np.sum(X, axis = 0) - gini_this = gini(h_norm) - gini_vec = np.append(gini_vec, gini_this) - - inds_columns_to_check = np.where(gini_vec > gini_baseline) - - indices_to_remove = [] - indices_to_remove = np.array(inds_columns_to_check) - - list_indices_to_keep = {} -# for i in np.array(inds_columns_to_check).tolist()[0]: - for i in inds_columns_to_check[0]: - h_norm = H[i,:]/np.sum(X, axis = 0) - index = h_norm.size - delta = 1 - while delta > gini_delta: - gini_this = gini(np.sort(h_norm)[1:index]) - gini_bef = gini(np.sort(h_norm)[1:(index - 1)]) - delta = gini_this - gini_bef - index = index - 1 - if index < np.around(h_norm.size * 0.8): - break - to_keep = np.where(h_norm < np.sort(h_norm)[index]) - to_remove = np.where(h_norm >= np.sort(h_norm)[index]) - list_indices_to_keep[i] = to_keep - indices_to_remove = np.append(indices_to_remove, to_remove) - - indices_to_remove = np.unique(indices_to_remove) - - if per_signature: - list_X = {} - indices_to_keep_all = np.array(range(0, H.shape[1])) - for i in inds_columns_to_check[0]: - X_this = X[:,list_indices_to_keep[i][0]] - list_X[i] = X_this - indices_this = list_indices_to_keep[i][0] - indices_to_keep_all = np.array([i for i in indices_to_keep_all if i in indices_this]) - return(list_X, inds_columns_to_check, list_indices_to_keep, indices_to_keep_all) - else: - X_this = np.delete(X, indices_to_remove, axis = 1) - indices_to_keep = np.array([i for i in range(0, index) if i not in indices_to_remove]) - return(X_this, indices_to_keep) +def sort_with_indices(x): + """ + x has to be a numpy array + """ + indices_sorted = np.argsort(x) + + return x[indices_sorted], indices_sorted + + +def gini(x, assume_sorted=True): + """ + x is assumed to be increasing + """ + if not assume_sorted: + x = np.sort(x) + n = len(x) + aux = [xi * (2 * i - n + 1) for i, xi in enumerate(x)] + scaling = 1 / (n * sum(x)) + + return scaling * sum(aux) + + +def n_remove_gini(x, gini_delta, thresh): + """ + Identify how many of the largest values of x significantly contribute to its Gini coefficient. + Do not identify more than (1 - thresh) * 100 % of x. + x is assumed to be increasing. + """ + n_remove = 0 + max_n_remove = np.round((1 - thresh) * len(x)) + + gini_old = gini(x) + gini_new = gini(x[:-1]) + n_remove = 0 + + while gini_old - gini_new > gini_delta and n_remove < max_n_remove: + + n_remove += 1 + gini_old = gini_new + gini_new = gini(x[: - n_remove - 1]) + + return n_remove + + +def remove_samples_based_on_gini(H, X, gini_baseline=.65, gini_delta=.005): + """ + Identify signatures with unequal exposures. A signature is said to have unequal exposures if the + Gini coefficient of the sample exposures is higher than a given threshold. + For these signatures, the samples causing the gini coefficient to be high are also identified. + + Input: + ------ + H: np.ndarray + The exposure matrix of shape (n_signatures, n_samples) + + X: np.ndarray + The mutation count matrix of shape (n_features, n_samples) + + gini_baseline: float + Signatures with exposures having a higher Gini coefficient than 'gini_baseline' are identified + as having unequal exposures + + gini_delta: float + Per signature with unequal exposure, a sample is identified as a sample significanlty contributing + the high Gini coefficient if removing it decreases the Gini coefficient by at least 'gini_delta' + + Output: + ------ + samples_to_keep: dict + keys: indices of signatures with unequal exposures + values: corresponding sample indices that do not (!) cause the Gini coefficient to be high + + X_to_keep: dict + keys: indices of signatures with unequal exposures + values: mutation count matix subsetted to the samples that do not (!) cause the Gini coefficient to be high + + samples_to_keep_all: np.ndarray + List of sample indices not significantly causing the Gini coefficient of any signature with unequal exposure to be high + """ + H, X = np.array(H), np.array(X) + + n_samples = H.shape[1] + + # normalize the exposures + H = H /np.sum(X, axis=0) + + # Gini coefficients of normalized signature exposures + gini_coeffs = np.array([gini(np.sort(h)) for h in H]) + sigs_to_check = np.where(gini_coeffs > gini_baseline)[0] + + samples_to_keep = {} + samples_to_remove = set() + + for sig_index in sigs_to_check: + + sorted_h, sorted_h_indices = sort_with_indices(H[sig_index,:]) + n_remove = n_remove_gini(sorted_h, gini_delta, .8) + + to_keep, to_remove = np.split(sorted_h_indices, [-n_remove]) if n_remove else (sorted_h_indices, np.empty(0)) + samples_to_keep[sig_index] = np.array(sorted(to_keep)) + samples_to_remove |= set(to_remove) + + X_to_keep = {sig_index: X[:, samples] for sig_index, samples in samples_to_keep.items()} + + samples_to_keep_all = set(range(n_samples)) - samples_to_remove + samples_to_keep_all = np.array(sorted(samples_to_keep_all)) + + results = (samples_to_keep, X_to_keep, samples_to_keep_all) + + return results + def identify_distinct_cluster(X, H, frac_thresh=0.05): """Identify distinct clusters from the cohort based on exposures. @@ -256,12 +315,12 @@ def stratify_samples(X, H=None, sil_thresh=0.9, i.e., at most 1 cluster with silhouette score < sil_thresh, to accept the full clustering. Otherwise a smaller k might be more appropriate. This is a bit more complicated. So we ignore this for now. """ - + if H is None: data = normalize(X, norm='l1', axis=0) else: data = normalize(H, norm='l1', axis=0) - + n_samples = data.shape[1] # Clustering with automatic selection of cluster number optimalK = OptimalK(data, max_k=max_k, nrefs=nrefs, metric=metric, linkage_method=linkage_method, ref_method=ref_method) @@ -270,7 +329,7 @@ def stratify_samples(X, H=None, sil_thresh=0.9, if k > 1: # If k > 1, we check per-cluster silhouette scores. # If at least one cluster has silhouette score > sil_thresh, we accept the clustering. - # Otherwise we reject it. + # Otherwise we reject it. if np.any(optimalK.silscorek_percluster[k] > sil_thresh): cluster_membership = optimalK.cluster_membership clusters = []