From 7a3269803733b4b8f04f6bc6eee86f4524574901 Mon Sep 17 00:00:00 2001 From: Adam Gayoso Date: Wed, 6 Jun 2018 21:05:33 -0400 Subject: [PATCH] Fix masked labels filling for ambiguous cells (#97) * Fix labels filling for ambiguous cells Filling accidentally assigned a true value to cells that were never clustered. This properly assigns these cells an np.nan label. * Correct np.ma.filled and np.nan handling in plot.py. Co-authored-by: Jonathan Shor <11262246+JonathanShor@users.noreply.github.com> Co-authored-by: Adam Gayoso --- src/doubletdetection/doubletdetection.py | 2 +- src/doubletdetection/plot.py | 17 +++++++++-------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/doubletdetection/doubletdetection.py b/src/doubletdetection/doubletdetection.py index 2a16f1f..3bd009e 100644 --- a/src/doubletdetection/doubletdetection.py +++ b/src/doubletdetection/doubletdetection.py @@ -228,7 +228,7 @@ def predict(self, p_thresh=0.99, voter_thresh=0.9): with np.errstate(invalid='ignore'): # Silence numpy warning about NaN comparison self.voting_average_ = np.mean(np.ma.masked_invalid(self.all_p_values_) >= p_thresh, axis=0) - self.labels_ = np.ma.filled(self.voting_average_ >= voter_thresh, np.nan) + self.labels_ = np.ma.filled((self.voting_average_ >= voter_thresh).astype(float), np.nan) self.voting_average_ = np.ma.filled(self.voting_average_, np.nan) else: # Find a cutoff score diff --git a/src/doubletdetection/plot.py b/src/doubletdetection/plot.py index 8884f10..3bf8b44 100644 --- a/src/doubletdetection/plot.py +++ b/src/doubletdetection/plot.py @@ -58,10 +58,9 @@ def convergence(clf, show=False, save=None, p_thresh=0.99, voter_thresh=0.9): with np.errstate(invalid='ignore'): for i in range(clf.n_iters): cum_p_values = clf.all_p_values_[:i + 1] - cum_vote_average = np.mean( - np.ma.masked_invalid(cum_p_values) > p_thresh, axis=0) - cum_doublets = np.ma.filled(cum_vote_average >= voter_thresh, np.nan) - doubs_per_run.append(np.sum(cum_doublets)) + 
cum_vote_average = np.mean(np.ma.masked_invalid(cum_p_values) > p_thresh, axis=0) + cum_doublets = np.ma.filled((cum_vote_average >= voter_thresh).astype(float), np.nan) + doubs_per_run.append(np.nansum(cum_doublets)) # Ignore warning for convergence plot with warnings.catch_warnings(): @@ -111,19 +110,21 @@ def tsne(raw_counts, labels, n_components=30, n_jobs=-1, show=False, save=None, svd_solver='randomized').fit_transform(norm_counts) communities, _, _ = phenograph.cluster(reduced_counts) tsne_counts = TSNE(n_jobs=-1).fit_transform(reduced_counts) + # Ensure only looking at positively identified doublets + doublets = labels == 1 fig, axes = plt.subplots(1, 1, figsize=(3, 3), dpi=200) axes.scatter(tsne_counts[:, 0], tsne_counts[:, 1], c=communities, cmap=plt.cm.tab20, s=1) - axes.scatter(tsne_counts[:, 0][labels], tsne_counts[:, 1] - [labels], s=3, edgecolor='k', facecolor='k') + axes.scatter(tsne_counts[:, 0][doublets], tsne_counts[:, 1][doublets], + s=3, edgecolor='k', facecolor='k') axes.set_title('Cells with Detected\n Doublets in Black') plt.xticks([]) plt.yticks([]) axes.set_xlabel('{} doublets out of {} cells.\n {}% across-type doublet rate.'.format( - np.sum(labels), + np.sum(doublets), raw_counts.shape[0], - np.round(100 * np.sum(labels) / raw_counts.shape[0], 2))) + np.round(100 * np.sum(doublets) / raw_counts.shape[0], 2))) if show is True: plt.show()