diff --git a/src/vicdyf/dataset.py b/src/vicdyf/dataset.py index 0481c57..c2d6035 100644 --- a/src/vicdyf/dataset.py +++ b/src/vicdyf/dataset.py @@ -23,7 +23,7 @@ def __init__(self, s, u, test_ratio, batch_size, num_workers, validation_ratio=0 s = s.float() u = u.float() norm_mat = torch.sum(s, dim=1).view(-1, 1) * torch.sum(s, dim=0).view(1, -1) - norm_mat = norm_mat / torch.mean(norm_mat) + norm_mat = torch.mean(s) * norm_mat / torch.mean(norm_mat) self.s = s self.u = u self.norm_mat = norm_mat diff --git a/src/vicdyf/modules.py b/src/vicdyf/modules.py index 8a21f70..155497f 100644 --- a/src/vicdyf/modules.py +++ b/src/vicdyf/modules.py @@ -84,6 +84,7 @@ def __init__( self.loggamma = Parameter(torch.Tensor(x_dim)) self.logbeta = Parameter(torch.Tensor(x_dim)) self.softplus = nn.Softplus() + self.sigmoid = nn.Sigmoid() self.reset_parameters() self.no_lu = False self.no_d_kld = False @@ -109,9 +110,7 @@ def forward(self, x): pxd_zd_ld = self.dec_z(z + d) pxmd_zd_ld = self.dec_z(z - d) diff_px_zd_ld = pxd_zd_ld - pxmd_zd_ld - raw_gamma = self.softplus(self.loggamma) - normalize_coeff = self.gamma_mean / raw_gamma.mean() - gamma = self.dt * normalize_coeff * raw_gamma + gamma = self.softplus(self.loggamma) beta = self.softplus(self.logbeta) * self.dt pu_zd_ld = self.softplus(diff_px_zd_ld + px_z_ld * gamma) / beta return(z, d, qz, qd, px_z_ld, pu_zd_ld) diff --git a/src/vicdyf/utils.py b/src/vicdyf/utils.py index bbdc9ca..90e0204 100644 --- a/src/vicdyf/utils.py +++ b/src/vicdyf/utils.py @@ -97,7 +97,7 @@ def calc_gene_mean_sd(z, qd, dcoeff, model, sample_num=50): return(gene_mean, gene_sd, batch_mean_mat, batch_std_mat) -def post_process(adata, vicdyf_exp, sigma=0.05, nn=30, mdist=0.1, dz_var_prop=0.05, sample_num=10): +def post_process(adata, vicdyf_exp, sigma=0.05, n_neighbors=30, min_dist=0.1, dz_var_prop=0.05, sample_num=10): x = vicdyf_exp.edm.s u = vicdyf_exp.edm.u vicdyf_exp.device = torch.device('cpu') @@ -135,18 +135,38 @@ def post_process(adata, vicdyf_exp, sigma=0.05, nn=30, mdist=0.1, dz_var_prop=0. adata.layers['vicdyf_velocity'] = gene_vel adata.layers['vicdyf_mean_velocity'] = mean_gene_vel adata.layers['vicdyf_fluctuation'] = batch_std_mat - adata.obs['vicdyf_fluctuation'] = np.mean(adata.layers['vicdyf_fluctuation']) - adata.obs['vicdyf_velocity'] = np.mean(np.abs(adata.layers['vicdyf_velocity'])) + adata.obs['vicdyf_fluctuation'] = np.mean(adata.layers['vicdyf_fluctuation'], axis=1) + adata.obs['vicdyf_velocity'] = np.mean(np.abs(adata.layers['vicdyf_velocity']), axis=1) + adata.obs['vicdyf_mean_velocity'] = np.mean(np.abs(adata.layers['vicdyf_mean_velocity']), axis=1) # calculate transition rate stoc_tr_mat = calc_tr_mat(zl.cpu().detach(), d.cpu().detach(), sigma) mean_tr_mat = calc_tr_mat(zl.cpu().detach(), dl.cpu().detach(), sigma) # embed z - z_embed = embed_z(zl_mat, n_neighbors=nn, min_dist=mdist) + z_embed = embed_z(zl_mat, n_neighbors=n_neighbors, min_dist=min_dist) adata.obsm['X_vicdyf_umap'] = z_embed - stoc_d_embed = embed_tr_mat(z_embed, stoc_tr_mat, gene_norm) - mean_d_embed =embed_tr_mat(z_embed, mean_tr_mat, mean_gene_norm) + stoc_d_embed = embed_tr_mat(z_embed, stoc_tr_mat, adata.obs['vicdyf_velocity'].values) + mean_d_embed =embed_tr_mat(z_embed, mean_tr_mat, adata.obs['vicdyf_mean_velocity'].values) adata.obsp['stoc_tr_mat'] = stoc_tr_mat.detach().numpy() adata.obsp['mean_tr_mat'] = mean_tr_mat.detach().numpy() adata.obsm['X_vicdyf_sdumap'] = stoc_d_embed.cpu().detach().numpy() adata.obsm['X_vicdyf_mdumap'] = mean_d_embed.cpu().detach().numpy() return(adata) + +def change_visualization(adata, embeddings=None, n_neighbors=30, min_dist=0.1): + # embed z + if embeddings == None: + z_embed = embed_z(adata.obsm['X_vicdyf_zl'], n_neighbors=n_neighbors, min_dist=min_dist) + else: + if type(embeddings) == str: + z_embed = adata.obsm[embeddings] + else: + z_embed = embeddings + adata.obsm['X_vicdyf_umap'] = z_embed + stoc_tr_mat = torch.tensor(adata.obsp['stoc_tr_mat']) + mean_tr_mat = torch.tensor(adata.obsp['mean_tr_mat']) + stoc_d_embed = embed_tr_mat(z_embed, stoc_tr_mat, adata.obs['vicdyf_velocity'].values) + mean_d_embed =embed_tr_mat(z_embed, mean_tr_mat, adata.obs['vicdyf_mean_velocity'].values) + adata.obsm['X_vicdyf_sdumap'] = stoc_d_embed.cpu().detach().numpy() + adata.obsm['X_vicdyf_mdumap'] = mean_d_embed.cpu().detach().numpy() + return(adata) + diff --git a/src/vicdyf/workflow.py b/src/vicdyf/workflow.py index 09190ea..c104452 100644 --- a/src/vicdyf/workflow.py +++ b/src/vicdyf/workflow.py @@ -11,8 +11,9 @@ def estimate_dynamics( 'num_enc_z_layers': 2, 'num_enc_d_layers': 2, 'num_dec_z_layers': 2 }, - lr=0.0001, val_ratio=0.05, test_ratio=0.1, - batch_size=300, num_workers=1, sample_num=10): + lr=0.001, val_ratio=0, test_ratio=0.05, + batch_size=100, num_workers=1, sample_num=10, + n_neighbors=30, min_dist=0.1): if use_genes == None: use_genes = adata.var_names utils.input_checks(adata) @@ -28,7 +29,7 @@ def estimate_dynamics( print('Start first opt') for param in vicdyf_exp.model.enc_d.parameters(): param.requires_grad = False - vicdyf_exp.init_optimizer(0.0001) + vicdyf_exp.init_optimizer(lr) vicdyf_exp.train_total(first_epoch) print('Done first opt') print(f'Loss:{vicdyf_exp.test()}') @@ -37,11 +38,11 @@ def estimate_dynamics( vicdyf_exp.model.no_d_kld = False for param in vicdyf_exp.model.enc_d.parameters(): param.requires_grad = True - vicdyf_exp.init_optimizer(0.0001) + vicdyf_exp.init_optimizer(lr) vicdyf_exp.train_total(second_epoch) print('Done second opt') print(f'Loss:{vicdyf_exp.test()}') torch.save(vicdyf_exp.model.state_dict(), param_path) adata.uns['param_path'] = param_path - adata = utils.post_process(adata, vicdyf_exp, sample_num=sample_num) + adata = utils.post_process(adata, vicdyf_exp, sample_num=sample_num, n_neighbors=n_neighbors, min_dist=min_dist) return(adata) diff --git a/tutorial/application_on_pancreas.py b/tutorial/application_on_pancreas.py index 1bcf8f2..7224b70 100644 --- a/tutorial/application_on_pancreas.py +++ b/tutorial/application_on_pancreas.py @@ -1,9 +1,14 @@ import vicdyf import scvelo as scv +from matplotlib import pyplot as plt adata = scv.datasets.pancreas() -scv.pp.filter_and_normalize(adata, min_shared_counts=20, n_top_genes=4000) +scv.pp.filter_and_normalize(adata, min_shared_counts=20, n_top_genes=2000) scv.pp.moments(adata, n_pcs=30, n_neighbors=30) raw_adata = scv.datasets.pancreas() adata.layers['spliced'] = raw_adata[:, adata.var_names].layers['spliced'] adata.layers['unspliced'] = raw_adata[:, adata.var_names].layers['unspliced'] -adata = vicdyf.workflow.estimate_dynamics(adata) +adata = vicdyf.workflow.estimate_dynamics(adata)#, first_epoch=10, second_epoch=10) +adata = vicdyf.utils.change_visualization(adata, n_neighbors=100) +adata = vicdyf.utils.change_visualization(adata, embeddings='X_vicdyf_umap') +scv.pl.velocity_embedding_grid(adata,X=adata.obsm['X_vicdyf_umap'], V=adata.obsm['X_vicdyf_mdumap'], color='vicdyf_fluctuation', show=False, basis='X_vicdyf_umap', density=0.3) +plt.savefig('tutorial/pancreas_flow.png')