Skip to content

Commit

Permalink
v.2.3 Log p-values (#101)
Browse files Browse the repository at this point in the history
* add all_log_p_values_ attribute which is used in predict and convergence.
Calculate the log survival function instead of the CDF.
The p-value threshold now uses the left tail, so the `p_thresh` parameter expects values near zero rather than near one.

* `all_p_values_` deprecated. Will be removed in v3.0.

* add threshold diagnostics plot.
New plot showing number of predicted doublets across a grid of thresholds.
Allows log10 p-values to be shown in threshold plot.

* Update convergence plot to use log p values.

* updated doc with pseudocode to reflect changes.
Also removed downsampling from the doc.

* Update notebook and add threshold plot to notebook.

* Update README.md

* Remove unneeded setup.py imports.

Co-authored-by: Jonathan Shor <[email protected]>
Co-authored-by: Adam Gayoso <[email protected]>
  • Loading branch information
JonathanShor and adamgayoso authored Jun 13, 2018
1 parent 7a32698 commit 0d807fa
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 32 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ clf = doubletdetection.BoostClassifier()
labels = clf.fit(raw_counts).predict()
```

`raw_counts` is a scRNA-seq count matrix (cells by genes), and is array-like. `labels` is a binary 1-dimensional numpy ndarray with the value 1 representing a
detected doublet.
`raw_counts` is a scRNA-seq count matrix (cells by genes), and is array-like. `labels` is a 1-dimensional numpy ndarray with the value 1 representing a detected doublet, 0 a singlet, and `np.nan` an ambiguous cell.

The classifier works best when there are several cell types present in the data. Furthermore, it should be applied individually to each run in an aggregated count matrix.

See our [jupyter notebook](https://nbviewer.jupyter.org/github/JonathanShor/DoubletDetection/blob/master/docs/PBMC_8k_vignette.ipynb) for an example on 8k PBMCs from 10x.

Expand All @@ -30,7 +31,6 @@ Data can be downloaded from the [10x website](https://support.10xgenomics.com/si


## Citations

bioRxiv submission is in the works.

This project is licensed under the terms of the MIT license.
Binary file modified docs/DoubletDetection.pdf
Binary file not shown.
42 changes: 36 additions & 6 deletions docs/PBMC_8k_vignette.ipynb

Large diffs are not rendered by default.

4 changes: 1 addition & 3 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from setuptools import setup
import sys
import shutil
from subprocess import check_call

CLASSIFIERS = [
"Development Status :: 4 - Beta",
Expand All @@ -19,7 +17,7 @@

setup(
name='doubletdetection',
version='2.2.0',
version='2.3.0',
description='Method to detect and enable removal of doublets from single-cell RNA-sequencing '
'data',
url='https://github.com/JonathanShor/DoubletDetection',
Expand Down
42 changes: 25 additions & 17 deletions src/doubletdetection/doubletdetection.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,12 @@ class BoostClassifier:
pseudocount=new_var)
Attributes:
all_p_values_ (ndarray): Hypergeometric test p-value per cell for cluster
enrichment of synthetic doublets. Shape (n_iters, num_cells).
all_log_p_values_ (ndarray): Hypergeometric test natural log p-value per
cell for cluster enrichment of synthetic doublets. Shape (n_iters,
num_cells).
all_p_values_ (ndarray): DEPRECATED. Exponentiated all_log_p_values.
Due to rounding point errors, use of all_log_p_values recommended.
Will be removed in v3.0.
all_scores_ (ndarray): The fraction of a cell's cluster that is
synthetic doublets. Shape (n_iters, num_cells).
communities_ (ndarray): Cluster ID for corresponding cell. Shape
Expand Down Expand Up @@ -153,8 +157,8 @@ def fit(self, raw_counts):
raw_counts (array-like): Count matrix, oriented cells by genes.
Sets:
all_scores_, all_p_values_, communities_, top_var_genes, parents,
synth_communities
all_scores_, all_p_values_, all_log_p_values_, communities_,
top_var_genes, parents, synth_communities
Returns:
The fitted classifier.
Expand All @@ -181,14 +185,14 @@ def fit(self, raw_counts):
self._normed_raw_counts = self._raw_counts / self._lib_size[:, np.newaxis]

self.all_scores_ = np.zeros((self.n_iters, self._num_cells))
self.all_p_values_ = np.zeros((self.n_iters, self._num_cells))
self.all_log_p_values_ = np.zeros((self.n_iters, self._num_cells))
all_communities = np.zeros((self.n_iters, self._num_cells))
all_parents = []
all_synth_communities = np.zeros((self.n_iters, int(self.boost_rate * self._num_cells)))

for i in range(self.n_iters):
print("Iteration {:3}/{}".format(i + 1, self.n_iters))
self.all_scores_[i], self.all_p_values_[i] = self._one_fit()
self.all_scores_[i], self.all_log_p_values_[i] = self._one_fit()
all_communities[i] = self.communities_
all_parents.append(self.parents_)
all_synth_communities[i] = self.synth_communities_
Expand All @@ -205,10 +209,11 @@ def fit(self, raw_counts):
self.communities_ = all_communities
self.parents_ = all_parents
self.synth_communities_ = all_synth_communities
self.all_p_values_ = np.exp(self.all_log_p_values_)

return self

def predict(self, p_thresh=0.99, voter_thresh=0.9):
def predict(self, p_thresh=0.01, voter_thresh=0.9):
"""Produce doublet calls from fitted classifier
Args:
Expand All @@ -224,11 +229,13 @@ def predict(self, p_thresh=0.99, voter_thresh=0.9):
Returns:
labels_ (ndarray, ndims=1): 0 for singlet, 1 for detected doublet
"""
log_p_thresh = np.log(p_thresh)
if self.n_iters > 1:
with np.errstate(invalid='ignore'): # Silence numpy warning about NaN comparison
self.voting_average_ = np.mean(np.ma.masked_invalid(self.all_p_values_) >= p_thresh,
axis=0)
self.labels_ = np.ma.filled((self.voting_average_ >= voter_thresh).astype(float), np.nan)
self.voting_average_ = np.mean(
np.ma.masked_invalid(self.all_log_p_values_) <= log_p_thresh, axis=0)
self.labels_ = np.ma.filled((self.voting_average_ >= voter_thresh).astype(float),
np.nan)
self.voting_average_ = np.ma.filled(self.voting_average_, np.nan)
else:
# Find a cutoff score
Expand Down Expand Up @@ -288,17 +295,18 @@ def _one_fit(self):
for i in community_IDs}
scores = np.array([community_scores[i] for i in self.communities_])

community_p_values = {i: hypergeom.cdf(synth_cells_per_comm[i], aug_counts.shape[0],
self._synthetics.shape[0],
synth_cells_per_comm[i] + orig_cells_per_comm[i])
for i in community_IDs}
p_values = np.array([community_p_values[i] for i in self.communities_])
community_log_p_values = {i: hypergeom.logsf(synth_cells_per_comm[i], aug_counts.shape[0],
self._synthetics.shape[0],
synth_cells_per_comm[i] +
orig_cells_per_comm[i])
for i in community_IDs}
log_p_values = np.array([community_log_p_values[i] for i in self.communities_])

if min_ID < 0:
scores[self.communities_ == -1] = np.nan
p_values[self.communities_ == -1] = np.nan
log_p_values[self.communities_ == -1] = np.nan

return scores, p_values
return scores, log_p_values

def _downsampleCellPair(self, cell1, cell2):
"""Downsample the sum of two cells' gene expression profiles.
Expand Down
78 changes: 75 additions & 3 deletions src/doubletdetection/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def normalize_counts(raw_counts, pseudocount=0.1):
return normed


def convergence(clf, show=False, save=None, p_thresh=0.99, voter_thresh=0.9):
def convergence(clf, show=False, save=None, p_thresh=0.01, voter_thresh=0.9):
"""Produce a plot showing number of cells called doublet per iter
Args:
Expand All @@ -53,12 +53,14 @@ def convergence(clf, show=False, save=None, p_thresh=0.99, voter_thresh=0.9):
Returns:
matplotlib figure
"""
log_p_thresh = np.log(p_thresh)
doubs_per_run = []
# Ignore numpy complaining about np.nan comparisons
with np.errstate(invalid='ignore'):
for i in range(clf.n_iters):
cum_p_values = clf.all_p_values_[:i + 1]
cum_vote_average = np.mean(np.ma.masked_invalid(cum_p_values) > p_thresh, axis=0)
cum_log_p_values = clf.all_log_p_values_[:i + 1]
cum_vote_average = np.mean(np.ma.masked_invalid(cum_log_p_values) <= log_p_thresh,
axis=0)
cum_doublets = np.ma.filled((cum_vote_average >= voter_thresh).astype(float), np.nan)
doubs_per_run.append(np.nansum(cum_doublets))

Expand Down Expand Up @@ -132,3 +134,73 @@ def tsne(raw_counts, labels, n_components=30, n_jobs=-1, show=False, save=None,
fig.savefig(save, format='pdf', bbox_inches='tight')

return fig, tsne_counts


def threshold(clf, log10=True, show=False, save=None, p_grid=None, voter_grid=None, v_step=20,
              p_step=20):
    """Produce a plot showing number of cells called doublet across
       various thresholds

    For every (voting threshold, log p-value threshold) pair, a cell is
    counted as a doublet when the fraction of iterations in which its log
    p-value is at or below the p-value threshold reaches the voting
    threshold. The counts are drawn as a heatmap.

    Args:
        clf (BoostClassifier object): Fitted classifier
        log10 (bool, optional): Use natural log p values if False, log10
            otherwise.
        show (bool, optional): If True, runs plt.show()
        save (str, optional): If provided, the figure is saved to this
            filepath.
        p_grid (ndarray, optional): log p-value thresholds to use, expressed
            in the same log base as selected by `log10`. Defaults to the
            unique observed log p-values that fall below log(0.01).
        voter_grid (ndarray, optional): Voting thresholds to use. Defaults to
            np.arange(0.3, 1.0, 0.01).
        v_step (int, optional): spacing between labelled y ticks, i.e. every
            v_step-th voting threshold gets a tick label
        p_step (int, optional): spacing between labelled x ticks, i.e. every
            p_step-th p-value threshold gets a tick label

    Returns:
        matplotlib figure
    """
    use_log10 = log10 is True
    # Ignore numpy complaining about np.nan comparisons
    with np.errstate(invalid='ignore'):
        all_log_p_values = np.copy(clf.all_log_p_values_)
        if use_log10:
            our_log = np.log10
            # Change of base: log10(p) = ln(p) / ln(10)
            all_log_p_values /= np.log(10)
        else:
            our_log = np.log
        if p_grid is None:
            p_grid = np.unique(all_log_p_values)
            # NaN entries compare False here and are therefore dropped too
            p_grid = p_grid[p_grid < our_log(0.01)]
        if voter_grid is None:
            voter_grid = np.arange(0.3, 1.0, 0.01)

        # Compare in the same log base as p_grid. The classifier stores natural
        # log p-values, so the converted copy must be used, not the raw
        # attribute, otherwise log10 thresholds are applied to natural logs.
        masked_log_p_values = np.ma.masked_invalid(all_log_p_values)

        doubs_per_t = np.zeros((len(voter_grid), len(p_grid)))
        for j, p_cutoff in enumerate(p_grid):
            # The voting average depends only on the p-value threshold, so it
            # is computed once per column rather than once per grid cell.
            voting_average = np.mean(masked_log_p_values <= p_cutoff, axis=0)
            for i, voter_cutoff in enumerate(voter_grid):
                labels = np.ma.filled((voting_average >= voter_cutoff).astype(float), np.nan)
                doubs_per_t[i, j] = np.nansum(labels)

    # Ignore matplotlib tight_layout warnings for this plot
    with warnings.catch_warnings():
        warnings.filterwarnings(action="ignore", module="matplotlib", message="^tight_layout")

        f, ax = plt.subplots(1, 1, figsize=(3, 3), dpi=200)
        cax = ax.imshow(doubs_per_t, cmap='hot', aspect='auto')
        ax.set_xticks(np.arange(len(p_grid))[::p_step])
        ax.set_xticklabels(np.around(p_grid, 1)[::p_step], rotation='vertical')
        ax.set_yticks(np.arange(len(voter_grid))[::v_step])
        ax.set_yticklabels(np.around(voter_grid, 2)[::v_step])
        cbar = f.colorbar(cax)
        cbar.set_label('Predicted Doublets')
        if use_log10:
            ax.set_xlabel("Log10 p-value")
        else:
            ax.set_xlabel("Log p-value")
        ax.set_ylabel("Voting Threshold")
        ax.set_title('Threshold Diagnostics')

    if show is True:
        plt.show()
    if save:
        f.savefig(save, format='pdf', bbox_inches='tight')

    return f

0 comments on commit 0d807fa

Please sign in to comment.