Skip to content

Commit

Permalink
global overdispersion, no Dirichlet
Browse files Browse the repository at this point in the history
  • Loading branch information
sjfleming committed May 29, 2020
1 parent ead9481 commit 5134191
Show file tree
Hide file tree
Showing 11 changed files with 217 additions and 139 deletions.
8 changes: 8 additions & 0 deletions REQUIREMENTS-DOCKER.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
numpy
scipy
tables
pandas
pyro-ppl>=0.3.2
torch
scikit-learn
matplotlib
26 changes: 13 additions & 13 deletions cellbender/remove_background/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,23 +84,23 @@ def add_subparser_args(subparsers: argparse) -> argparse:
"correct prior for empty droplet counts "
"in the rare case where empty counts "
"are extremely high (over 200).")
subparser.add_argument("--z-dim", type=int, default=20,
subparser.add_argument("--z-dim", type=int, default=100,
dest="z_dim",
help="Dimension of latent variable z.")
subparser.add_argument("--z-layers", nargs="+", type=int, default=[500],
dest="z_hidden_dims",
help="Dimension of hidden layers in the encoder "
"for z.")
subparser.add_argument("--d-layers", nargs="+", type=int,
default=[5, 2, 2],
dest="d_hidden_dims",
help="Dimension of hidden layers in the encoder "
"for d.")
subparser.add_argument("--p-layers", nargs="+", type=int,
default=[100, 10],
dest="p_hidden_dims",
help="Dimension of hidden layers in the encoder "
"for p.")
# subparser.add_argument("--d-layers", nargs="+", type=int,
# default=[5, 2, 2],
# dest="d_hidden_dims",
# help="Dimension of hidden layers in the encoder "
# "for d.")
# subparser.add_argument("--p-layers", nargs="+", type=int,
# default=[100, 10],
# dest="p_hidden_dims",
# help="Dimension of hidden layers in the encoder "
# "for p.")
subparser.add_argument("--training-fraction",
type=float, nargs=None,
default=consts.TRAINING_FRACTION,
Expand All @@ -125,7 +125,7 @@ def add_subparser_args(subparsers: argparse) -> argparse:
"the counts for these genes will be set "
"to zero.")
subparser.add_argument("--fpr", nargs="+",
type=float, default=[0.02],
type=float, default=[0.01],
dest="fpr",
help="Target false positive rate in (0, 1). A false "
"positive is a true signal count that is "
Expand All @@ -140,7 +140,7 @@ def add_subparser_args(subparsers: argparse) -> argparse:
"will cause remove-background to operate on "
"gene counts only, ignoring other features.")
subparser.add_argument("--learning-rate", nargs=None,
type=float, default=5e-4,
type=float, default=2e-4,
dest="learning_rate",
help="Training detail: lower learning rate for "
"inference. A OneCycle learning rate schedule "
Expand Down
4 changes: 2 additions & 2 deletions cellbender/remove_background/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@

# Prior mean and std in log space for alpha0, the Dirichlet precision for cells.
# ALPHA0_PRIOR_LOC = 7. # TODO: using a heuristic in the model
ALPHA0_PRIOR_SCALE = 2.
# ALPHA0_PRIOR_SCALE = 2.

# Prior on rho, the swapping fraction: the two concentration parameters alpha and beta.
RHO_ALPHA_PRIOR = 18. # 1.5
RHO_BETA_PRIOR = 200. # 20.

# Prior on epsilon, the RT efficiency concentration parameter [Gamma(alpha, alpha)].
EPSILON_PRIOR = 20. # 1000.
EPSILON_PRIOR = 500. # 20. # 1000.
19 changes: 10 additions & 9 deletions cellbender/remove_background/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,8 +514,9 @@ def save_to_output_file(
z = enc['z']
d = enc['d']
p = enc['p']
alpha0 = enc['alpha0']
# alpha0 = enc['alpha0']
epsilon = enc['epsilon']
phi_params = enc['phi_loc_scale']

# Estimate the ambient-background-subtracted UMI count matrix.
if self.model_name != "simple":
Expand Down Expand Up @@ -566,7 +567,7 @@ def save_to_output_file(
inferred_count_matrix=inferred_count_matrix,
cell_barcode_inds=cell_barcode_inds,
ambient_expression=ambient_expression,
z=z, d=d, p=p, alpha=alpha0, epsilon=epsilon,
z=z, d=d, p=p, phi=phi_params, epsilon=epsilon,
rho=cellbender.remove_background.model.get_rho(),
fpr=self.fpr[0],
lambda_multiplier=self.posterior.lambda_multiplier,
Expand Down Expand Up @@ -598,7 +599,7 @@ def save_to_output_file(
z=z[filtered_inds_of_analyzed_barcodes, :],
d=d[filtered_inds_of_analyzed_barcodes],
p=p[filtered_inds_of_analyzed_barcodes],
alpha=alpha0[filtered_inds_of_analyzed_barcodes],
phi=phi_params,
epsilon=epsilon[filtered_inds_of_analyzed_barcodes],
rho=cellbender.remove_background.model.get_rho(),
fpr=self.fpr[0],
Expand Down Expand Up @@ -633,7 +634,7 @@ def save_to_output_file(
inferred_count_matrix=inferred_count_matrix,
cell_barcode_inds=self.analyzed_barcode_inds,
ambient_expression=ambient_expression,
z=z, d=d, p=p, alpha=alpha0, epsilon=epsilon,
z=z, d=d, p=p, phi=phi_params, epsilon=epsilon,
rho=cellbender.remove_background.model.get_rho(),
fpr=fpr,
lambda_multiplier=self.posterior.lambda_multiplier,
Expand Down Expand Up @@ -663,7 +664,7 @@ def save_to_output_file(
z=z[filtered_inds_of_analyzed_barcodes, :],
d=d[filtered_inds_of_analyzed_barcodes],
p=p[filtered_inds_of_analyzed_barcodes],
alpha=alpha0[filtered_inds_of_analyzed_barcodes],
phi=phi_params,
epsilon=epsilon[filtered_inds_of_analyzed_barcodes],
rho=cellbender.remove_background.model.get_rho(),
fpr=fpr,
Expand Down Expand Up @@ -1080,7 +1081,7 @@ def write_matrix_to_cellranger_h5(
z: Union[np.ndarray, None] = None,
d: Union[np.ndarray, None] = None,
p: Union[np.ndarray, None] = None,
alpha: Union[np.ndarray, None] = None,
phi: Union[np.ndarray, None] = None,
epsilon: Union[np.ndarray, None] = None,
fpr: Union[float, None] = None,
lambda_multiplier: Union[float, None] = None,
Expand All @@ -1107,7 +1108,7 @@ def write_matrix_to_cellranger_h5(
z: Latent encoding of gene expression.
d: Latent cell size scale factor.
p: Latent probability that a barcode contains a cell.
alpha: Latent Dirichlet precision parameter for each cell.
phi: Latent global overdispersion mean and scale.
fpr: Target false positive rate for the regularized posterior denoised
counts, where false positives are true counts that are (erroneously)
removed.
Expand Down Expand Up @@ -1215,8 +1216,8 @@ def write_matrix_to_cellranger_h5(
f.create_array(group, "latent_scale", d)
if p is not None:
f.create_array(group, "latent_cell_probability", p)
if alpha is not None:
f.create_array(group, "latent_dirichlet_precision", alpha)
if phi is not None:
f.create_array(group, "overdispersion_mean_and_scale", phi)
if rho is not None:
f.create_array(group, "contamination_fraction_params", rho)
if epsilon is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,12 @@ def log_prob(self, value):
"""
if self._validate_args:
self._validate_sample(value)
mu, alpha, lam, value = broadcast_all(self.mu, self.alpha, self.lam, value)
# mu, alpha, lam, value = broadcast_all(self.mu, self.alpha, self.lam, value)
mu, lam, value = broadcast_all(self.mu, self.lam, value)

# Use a moment-matched negative binomial approximation.
mean_hat = mu + lam
alpha_hat = mean_hat.pow(2) * alpha * mu.pow(-2)
alpha_hat = mean_hat.pow(2) * self.alpha * mu.pow(-2)
nb_approx_log_prob = self._neg_binom_log_prob(mu=mean_hat,
alpha=alpha_hat,
value=value)
Expand All @@ -123,9 +124,9 @@ def log_prob(self, value):
if torch.isnan(mu.log().sum()):
param.append('mu')
print(f'mu problem values: {mu[torch.isnan(mu.log())]}')
if torch.isnan(alpha.log().sum()):
if torch.isnan(self.alpha.log().sum()):
param.append('alpha')
print(f'alpha problem values: {alpha[torch.isnan(alpha.log())]}')
print(f'alpha value: {self.alpha}')
if torch.isnan(lam.log().sum()):
param.append('lam')
print(f'lam problem values: {lam[torch.isnan(lam.log())]}')
Expand Down
Loading

0 comments on commit 5134191

Please sign in to comment.