Skip to content

Commit

Permalink
Fix masked labels filling for ambiguous cells (#97)
Browse files Browse the repository at this point in the history
* Fix labels filling for ambiguous cells
Filling accidentally assigned a True value to cells that were never clustered. This change properly assigns these cells an np.nan label.

* Correct np.ma.filled and np.nan handling in plot.py.

Co-authored-by: Jonathan Shor <[email protected]>
Co-authored-by: Adam Gayoso <[email protected]>
  • Loading branch information
adamgayoso and JonathanShor committed Jun 7, 2018
1 parent 9281a09 commit 7a32698
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/doubletdetection/doubletdetection.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def predict(self, p_thresh=0.99, voter_thresh=0.9):
with np.errstate(invalid='ignore'): # Silence numpy warning about NaN comparison
self.voting_average_ = np.mean(np.ma.masked_invalid(self.all_p_values_) >= p_thresh,
axis=0)
self.labels_ = np.ma.filled(self.voting_average_ >= voter_thresh, np.nan)
self.labels_ = np.ma.filled((self.voting_average_ >= voter_thresh).astype(float), np.nan)
self.voting_average_ = np.ma.filled(self.voting_average_, np.nan)
else:
# Find a cutoff score
Expand Down
17 changes: 9 additions & 8 deletions src/doubletdetection/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,9 @@ def convergence(clf, show=False, save=None, p_thresh=0.99, voter_thresh=0.9):
with np.errstate(invalid='ignore'):
for i in range(clf.n_iters):
cum_p_values = clf.all_p_values_[:i + 1]
cum_vote_average = np.mean(
np.ma.masked_invalid(cum_p_values) > p_thresh, axis=0)
cum_doublets = np.ma.filled(cum_vote_average >= voter_thresh, np.nan)
doubs_per_run.append(np.sum(cum_doublets))
cum_vote_average = np.mean(np.ma.masked_invalid(cum_p_values) > p_thresh, axis=0)
cum_doublets = np.ma.filled((cum_vote_average >= voter_thresh).astype(float), np.nan)
doubs_per_run.append(np.nansum(cum_doublets))

# Ignore warning for convergence plot
with warnings.catch_warnings():
Expand Down Expand Up @@ -111,19 +110,21 @@ def tsne(raw_counts, labels, n_components=30, n_jobs=-1, show=False, save=None,
svd_solver='randomized').fit_transform(norm_counts)
communities, _, _ = phenograph.cluster(reduced_counts)
tsne_counts = TSNE(n_jobs=-1).fit_transform(reduced_counts)
# Ensure only looking at positively identified doublets
doublets = labels == 1

fig, axes = plt.subplots(1, 1, figsize=(3, 3), dpi=200)
axes.scatter(tsne_counts[:, 0], tsne_counts[:, 1],
c=communities, cmap=plt.cm.tab20, s=1)
axes.scatter(tsne_counts[:, 0][labels], tsne_counts[:, 1]
[labels], s=3, edgecolor='k', facecolor='k')
axes.scatter(tsne_counts[:, 0][doublets], tsne_counts[:, 1][doublets],
s=3, edgecolor='k', facecolor='k')
axes.set_title('Cells with Detected\n Doublets in Black')
plt.xticks([])
plt.yticks([])
axes.set_xlabel('{} doublets out of {} cells.\n {}% across-type doublet rate.'.format(
np.sum(labels),
np.sum(doublets),
raw_counts.shape[0],
np.round(100 * np.sum(labels) / raw_counts.shape[0], 2)))
np.round(100 * np.sum(doublets) / raw_counts.shape[0], 2)))

if show is True:
plt.show()
Expand Down

0 comments on commit 7a32698

Please sign in to comment.