Skip to content

Commit

Permalink
update normalization methodology
Browse files Browse the repository at this point in the history
  • Loading branch information
kojikoji committed Oct 25, 2021
1 parent c1b8d6a commit 12e4807
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 17 deletions.
2 changes: 1 addition & 1 deletion src/vicdyf/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self, s, u, test_ratio, batch_size, num_workers, validation_ratio=0
s = s.float()
u = u.float()
norm_mat = torch.sum(s, dim=1).view(-1, 1) * torch.sum(s, dim=0).view(1, -1)
norm_mat = norm_mat / torch.mean(norm_mat)
norm_mat = torch.mean(s) * norm_mat / torch.mean(norm_mat)
self.s = s
self.u = u
self.norm_mat = norm_mat
Expand Down
5 changes: 2 additions & 3 deletions src/vicdyf/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __init__(
self.loggamma = Parameter(torch.Tensor(x_dim))
self.logbeta = Parameter(torch.Tensor(x_dim))
self.softplus = nn.Softplus()
self.sigmoid = nn.Sigmoid()
self.reset_parameters()
self.no_lu = False
self.no_d_kld = False
Expand All @@ -109,9 +110,7 @@ def forward(self, x):
pxd_zd_ld = self.dec_z(z + d)
pxmd_zd_ld = self.dec_z(z - d)
diff_px_zd_ld = pxd_zd_ld - pxmd_zd_ld
raw_gamma = self.softplus(self.loggamma)
normalize_coeff = self.gamma_mean / raw_gamma.mean()
gamma = self.dt * normalize_coeff * raw_gamma
gamma = self.softplus(self.loggamma)
beta = self.softplus(self.logbeta) * self.dt
pu_zd_ld = self.softplus(diff_px_zd_ld + px_z_ld * gamma) / beta
return(z, d, qz, qd, px_z_ld, pu_zd_ld)
Expand Down
32 changes: 26 additions & 6 deletions src/vicdyf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def calc_gene_mean_sd(z, qd, dcoeff, model, sample_num=50):
return(gene_mean, gene_sd, batch_mean_mat, batch_std_mat)


def post_process(adata, vicdyf_exp, sigma=0.05, nn=30, mdist=0.1, dz_var_prop=0.05, sample_num=10):
def post_process(adata, vicdyf_exp, sigma=0.05, n_neighbors=30, min_dist=0.1, dz_var_prop=0.05, sample_num=10):
x = vicdyf_exp.edm.s
u = vicdyf_exp.edm.u
vicdyf_exp.device = torch.device('cpu')
Expand Down Expand Up @@ -135,18 +135,38 @@ def post_process(adata, vicdyf_exp, sigma=0.05, nn=30, mdist=0.1, dz_var_prop=0.
adata.layers['vicdyf_velocity'] = gene_vel
adata.layers['vicdyf_mean_velocity'] = mean_gene_vel
adata.layers['vicdyf_fluctuation'] = batch_std_mat
adata.obs['vicdyf_fluctuation'] = np.mean(adata.layers['vicdyf_fluctuation'])
adata.obs['vicdyf_velocity'] = np.mean(np.abs(adata.layers['vicdyf_velocity']))
adata.obs['vicdyf_fluctuation'] = np.mean(adata.layers['vicdyf_fluctuation'], axis=1)
adata.obs['vicdyf_velocity'] = np.mean(np.abs(adata.layers['vicdyf_velocity']), axis=1)
adata.obs['vicdyf_mean_velocity'] = np.mean(np.abs(adata.layers['vicdyf_mean_velocity']), axis=1)
# calculate transition rate
stoc_tr_mat = calc_tr_mat(zl.cpu().detach(), d.cpu().detach(), sigma)
mean_tr_mat = calc_tr_mat(zl.cpu().detach(), dl.cpu().detach(), sigma)
# embed z
z_embed = embed_z(zl_mat, n_neighbors=nn, min_dist=mdist)
z_embed = embed_z(zl_mat, n_neighbors=n_neighbors, min_dist=min_dist)
adata.obsm['X_vicdyf_umap'] = z_embed
stoc_d_embed = embed_tr_mat(z_embed, stoc_tr_mat, gene_norm)
mean_d_embed =embed_tr_mat(z_embed, mean_tr_mat, mean_gene_norm)
stoc_d_embed = embed_tr_mat(z_embed, stoc_tr_mat, adata.obs['vicdyf_velocity'].values)
mean_d_embed =embed_tr_mat(z_embed, mean_tr_mat, adata.obs['vicdyf_mean_velocity'].values)
adata.obsp['stoc_tr_mat'] = stoc_tr_mat.detach().numpy()
adata.obsp['mean_tr_mat'] = mean_tr_mat.detach().numpy()
adata.obsm['X_vicdyf_sdumap'] = stoc_d_embed.cpu().detach().numpy()
adata.obsm['X_vicdyf_mdumap'] = mean_d_embed.cpu().detach().numpy()
return(adata)

def change_visualization(adata, embeddings=None, n_neighbors=30, min_dist=0.1):
# embed z
if embeddings == None:
z_embed = embed_z(adata.obsm['X_vicdyf_zl'], n_neighbors=n_neighbors, min_dist=min_dist)
else:
if type(embeddings) == str:
z_embed = adata.obsm[embeddings]
else:
z_embed = embeddings
adata.obsm['X_vicdyf_umap'] = z_embed
stoc_tr_mat = torch.tensor(adata.obsp['stoc_tr_mat'])
mean_tr_mat = torch.tensor(adata.obsp['mean_tr_mat'])
stoc_d_embed = embed_tr_mat(z_embed, stoc_tr_mat, adata.obs['vicdyf_velocity'].values)
mean_d_embed =embed_tr_mat(z_embed, mean_tr_mat, adata.obs['vicdyf_mean_velocity'].values)
adata.obsm['X_vicdyf_sdumap'] = stoc_d_embed.cpu().detach().numpy()
adata.obsm['X_vicdyf_mdumap'] = mean_d_embed.cpu().detach().numpy()
return(adata)

11 changes: 6 additions & 5 deletions src/vicdyf/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ def estimate_dynamics(
'num_enc_z_layers': 2, 'num_enc_d_layers': 2,
'num_dec_z_layers': 2
},
lr=0.0001, val_ratio=0.05, test_ratio=0.1,
batch_size=300, num_workers=1, sample_num=10):
lr=0.001, val_ratio=0, test_ratio=0.05,
batch_size=100, num_workers=1, sample_num=10,
n_neighbors=30, min_dist=0.1):
if use_genes == None:
use_genes = adata.var_names
utils.input_checks(adata)
Expand All @@ -28,7 +29,7 @@ def estimate_dynamics(
print('Start first opt')
for param in vicdyf_exp.model.enc_d.parameters():
param.requires_grad = False
vicdyf_exp.init_optimizer(0.0001)
vicdyf_exp.init_optimizer(lr)
vicdyf_exp.train_total(first_epoch)
print('Done first opt')
print(f'Loss:{vicdyf_exp.test()}')
Expand All @@ -37,11 +38,11 @@ def estimate_dynamics(
vicdyf_exp.model.no_d_kld = False
for param in vicdyf_exp.model.enc_d.parameters():
param.requires_grad = True
vicdyf_exp.init_optimizer(0.0001)
vicdyf_exp.init_optimizer(lr)
vicdyf_exp.train_total(second_epoch)
print('Done second opt')
print(f'Loss:{vicdyf_exp.test()}')
torch.save(vicdyf_exp.model.state_dict(), param_path)
adata.uns['param_path'] = param_path
adata = utils.post_process(adata, vicdyf_exp, sample_num=sample_num)
adata = utils.post_process(adata, vicdyf_exp, sample_num=sample_num, n_neighbors=n_neighbors, min_dist=min_dist)
return(adata)
9 changes: 7 additions & 2 deletions tutorial/application_on_pancreas.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import vicdyf
import scvelo as scv
from matplotlib import pyplot as plt
adata = scv.datasets.pancreas()
scv.pp.filter_and_normalize(adata, min_shared_counts=20, n_top_genes=4000)
scv.pp.filter_and_normalize(adata, min_shared_counts=20, n_top_genes=2000)
scv.pp.moments(adata, n_pcs=30, n_neighbors=30)
raw_adata = scv.datasets.pancreas()
adata.layers['spliced'] = raw_adata[:, adata.var_names].layers['spliced']
adata.layers['unspliced'] = raw_adata[:, adata.var_names].layers['unspliced']
adata = vicdyf.workflow.estimate_dynamics(adata)
adata = vicdyf.workflow.estimate_dynamics(adata)#, first_epoch=10, second_epoch=10)
adata = vicdyf.utils.change_visualization(adata, n_neighbors=100)
adata = vicdyf.utils.change_visualization(adata, embeddings='X_vicdyf_umap')
scv.pl.velocity_embedding_grid(adata,X=adata.obsm['X_vicdyf_umap'], V=adata.obsm['X_vicdyf_mdumap'], color='vicdyf_fluctuation', show=False, basis='X_vicdyf_umap', density=0.3)
plt.savefig('tutorial/pancreas_flow.png')

0 comments on commit 12e4807

Please sign in to comment.