Skip to content

Commit

Permalink
publication-quality figure, testing passed
Browse files Browse the repository at this point in the history
  • Loading branch information
frankligy committed Jun 23, 2021
1 parent be7ff5a commit 65530fd
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 4,240 deletions.
76 changes: 56 additions & 20 deletions sctriangulate/main_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def add_to_invalid(self,invalid):
tmp = list(set(self.invalid))
self.invalid = tmp

def add_to_invalid_by_win_fraction(self,percent):
def add_to_invalid_by_win_fraction(self,percent=0.25):
df = self.uns['raw_cluster_goodness']
invalid = df.loc[df['win_fraction']<percent,:].index.tolist()
self.add_to_invalid(invalid)
Expand All @@ -222,7 +222,7 @@ def add_new_metrics(self,add_metrics):
self.add_metrics[metric] = func
self.total_metrics.extend(list(self.add_metrics.keys()))

def winners_statistics(self,col,plot,save):
def plot_winners_statistics(self,col,plot=True,save=True):
new_size_dict = {} # {gs@ERP4: 100}
for key,value in self.size_dict.items():
for sub_key,sub_value in value.items():
Expand All @@ -237,7 +237,6 @@ def winners_statistics(self,col,plot,save):
winners_stats = pd.concat([winners_vc,winners_size,winners_prop],axis=1)
winners_stats.columns = ['counts','size','proportion']
winners_stats.sort_values(by='proportion',inplace=True)
self.winners_stats = winners_stats
if plot:
a = winners_stats['proportion']
fig,ax = plt.subplots()
Expand All @@ -249,8 +248,9 @@ def winners_statistics(self,col,plot,save):
if save:
plt.savefig(os.path.join(self.dir,'winners_statistics.pdf'),bbox_inches='tight')
plt.close()
return winners_stats

def clusterability(self,col,plot,save):
def plot_clusterability(self,col,plot=True,save=True):
bucket = {} # {ERP4:5}
obs = self.adata.obs
for ref,grouped_df in obs.groupby(by=self.reference):
Expand All @@ -259,17 +259,24 @@ def clusterability(self,col,plot,save):
bucket = {k: v for k, v in sorted(bucket.items(), key=lambda x: x[1])}
if plot:
fig,ax = plt.subplots()
ax.scatter(x=np.arange(len(bucket)),y=list(bucket.values()),c='k',s=15)
ax.scatter(x=np.arange(len(bucket)),y=list(bucket.values()),c=pick_n_colors(len(bucket)),s=100)
ax.set_xticks(np.arange(len(bucket)))
ax.set_xticklabels(list(bucket.keys()),fontsize=3,rotation=90)
ax.set_title('{} clusterablity'.format(self.reference))
ax.set_ylabel('clusterability: # sub-clusters')
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.grid(color='grey',alpha=0.2)
for i in range(len(bucket)):
ax.text(x=i,y=list(bucket.values())[i]+1,s=list(bucket.keys())[i],ha='center',va='bottom')

if save:
plt.savefig(os.path.join(self.dir,'{}_clusterability.pdf'.format(self.reference)),bbox_inches='tight')
plt.close()
return bucket


def display_hierarchy(self,col):
def display_hierarchy(self,col,save=True):
obs = self.adata.obs
root = Node(self.reference)
hold_ref_var = {}
Expand All @@ -282,9 +289,13 @@ def display_hierarchy(self,col):
hold_cluster_var = {}
for item in unique:
hold_cluster_var[item] = Node(item,parent=hold_ref_var[ref])
with open(os.path.join(self.dir,'display_hierarchy_{}_{}.txt'.format(self.reference,col)),'a') as f:
if save:
with open(os.path.join(self.dir,'display_hierarchy_{}_{}.txt'.format(self.reference,col)),'a') as f:
for pre, fill, node in RenderTree(root):
print("%s%s" % (pre, node.name),file=f)
else:
for pre, fill, node in RenderTree(root):
print("%s%s" % (pre, node.name),file=f)
print("%s%s" % (pre, node.name))


def prune_statistics(self,print=False):
Expand Down Expand Up @@ -567,6 +578,7 @@ def pruning(self,method='reassign',discard=None,scale_sccaf=True,abs_thresh=10,r
obs, df = rank_pruning(self,discard=discard,scale_sccaf=scale_sccaf)
self.adata.obs = obs
self.uns['raw_cluster_goodness'] = df
self.adata.obs['confidence'] = self.adata.obs['pruned'].map(df['win_fraction'].to_dict())

self._prefixing(col='pruned')

Expand Down Expand Up @@ -594,30 +606,34 @@ def get_cluster(self):
self.adata.obs['user_choice'] = self.adata.obs['prefixed'].map(mapping).values


def plot_umap(self,col,kind='category',save=True,format='pdf'):
def plot_umap(self,col,kind='category',save=True,format='pdf',umap_dot_size=None):
# col means which column in obs to draw umap on
if umap_dot_size is None:
dot_size = 120000/self.adata.obs.shape[0]
else:
dot_size = umap_dot_size
if kind == 'category':
fig,ax = plt.subplots(nrows=2,ncols=1,figsize=(8,20),gridspec_kw={'hspace':0.3}) # for final_annotation
sc.pl.umap(self.adata,color=col,frameon=False,ax=ax[0])
sc.pl.umap(self.adata,color=col,frameon=False,legend_loc='on data',legend_fontsize=5,ax=ax[1])
sc.pl.umap(self.adata,color=col,frameon=False,ax=ax[0],size=dot_size)
sc.pl.umap(self.adata,color=col,frameon=False,legend_loc='on data',legend_fontsize=5,ax=ax[1],size=dot_size)
if save:
plt.savefig(os.path.join(self.dir,'umap_sctriangulate_{}.{}'.format(col,format)),bbox_inches='tight')
plt.close()
elif kind == 'continuous':
sc.pl.umap(self.adata,color=col,frameon=False,cmap=bg_greyed_cmap('viridis'),vmin=1e-5)
sc.pl.umap(self.adata,color=col,frameon=False,cmap=bg_greyed_cmap('viridis'),vmin=1e-5,size=dot_size)
if save:
plt.savefig(os.path.join(self.dir,'umap_sctriangulate_{}.{}'.format(col,format)),bbox_inches='tight')
plt.close()

def plot_confusion(self,name,key,save,format='pdf',**kwargs):
def plot_confusion(self,name,key,save=True,format='pdf',**kwargs):
df = self.uns[name][key]
df = df.apply(func=lambda x:x/x.sum(),axis=1)
sns.heatmap(df,cmap=scphere_cmap,**kwargs)
if save:
plt.savefig(os.path.join(self.dir,'confusion_{}_{}.{}'.format(name,key,format)),bbox_inches='tight')
plt.close()

def plot_cluster_feature(self,key,cluster,feature,enrichment_type='enrichr',save=False,format='pdf'):
def plot_cluster_feature(self,key,cluster,feature,enrichment_type='enrichr',save=True,format='pdf'):
if feature == 'enrichment':
fig,ax = plt.subplots()
a = self.uns['marker_genes'][key].loc[cluster,:][enrichment_type]
Expand Down Expand Up @@ -653,13 +669,18 @@ def plot_cluster_feature(self,key,cluster,feature,enrichment_type='enrichr',save
plt.savefig(os.path.join(self.dir,'{0}_{1}_location_umap.{2}'.format(key,cluster,format)),bbox_inches='tight')
plt.close()

def plot_heterogeneity(self,key,cluster,col,style,save=False,format='pdf',genes=None):
def plot_heterogeneity(self,key,cluster,style,col='pruned',save=True,format='pdf',genes=None,umap_zoom_out=True,umap_dot_size=None,
subset=None,marker_gene_dict=None,jitter=True,rotation=60):
adata_s = self.adata[self.adata.obs[key]==cluster,:]
# remove prior color stamps
tmp = adata_s.uns
tmp.pop('{}_colors'.format(col),None)
adata_s.uns = tmp

# only consider the sub-populations in subset list
if subset is not None:
adata_s = adata_s[adata_s.obs[col].isin(subset),:]

if style == 'build': # draw umap and heatmap

# umap
Expand Down Expand Up @@ -697,7 +718,17 @@ def plot_heterogeneity(self,key,cluster,col,style,save=False,format='pdf',genes=
elif style == 'umap':
fig,axes = plt.subplots(nrows=2,ncols=1,gridspec_kw={'hspace':0.5},figsize=(5,10))
# ax1
sc.pl.umap(adata_s,color=[col],ax=axes[0])
if umap_zoom_out:
umap_whole = self.adata.obsm['X_umap']
umap_x_lim = (umap_whole[:,0].min(),umap_whole[:,0].max())
umap_y_lum = (umap_whole[:,1].min(),umap_whole[:,1].max())
axes[0].set_xlim(umap_x_lim)
axes[0].set_ylim(umap_x_lim)
if umap_dot_size is None:
sc.pl.umap(adata_s,color=[col],ax=axes[0],size=120000/self.adata.obs.shape[0])
else:
sc.pl.umap(adata_s,color=[col],ax=axes[0],size=umap_dot_size)

# ax2
tmp_col = [1 if item == str(cluster) else 0 for item in self.adata.obs[key]]
self.adata.obs['tmp_plot'] = tmp_col
Expand All @@ -719,7 +750,6 @@ def plot_heterogeneity(self,key,cluster,col,style,save=False,format='pdf',genes=
adata_s = filter_DE_genes(adata_s,self.species,self.criterion)
number_of_groups = len(adata_s.obs[col].unique())
genes_to_pick = 50 // number_of_groups
print(adata_s.uns.keys())
sc.pl.rank_genes_groups_heatmap(adata_s,n_genes=genes_to_pick,swap_axes=True,key='rank_genes_groups_filtered')
if save:
plt.savefig(os.path.join(self.dir,'{}_{}_heterogeneity_{}_{}.{}'.format(key,cluster,col,style,format)),bbox_inches='tight')
Expand All @@ -728,7 +758,6 @@ def plot_heterogeneity(self,key,cluster,col,style,save=False,format='pdf',genes=
sc_marker_dict = {} # key is subgroup, value is a df containing markers
col_dict = {} # key is a colname, value is a numpy record array
colnames = ['names','scores','pvals','pvals_adj','logfoldchanges']
print(adata_s.uns.keys())
for item in colnames:
col_dict[item] = adata_s.uns['rank_genes_groups_filtered'][item]
for group in adata_s.obs[col].unique():
Expand All @@ -740,8 +769,14 @@ def plot_heterogeneity(self,key,cluster,col,style,save=False,format='pdf',genes=
sc_marker_dict[group] = df
return sc_marker_dict

elif style == 'heatmap_custom_gene':
sc.pl.heatmap(adata_s,marker_gene_dict,groupby=col,swap_axes=True,dendrogram=True)
if save:
plt.savefig(os.path.join(self.dir,'{}_{}_heterogeneity_{}_{}.{}'.format(key,cluster,col,style,format)),bbox_inches='tight')
plt.close()

elif style == 'violin':
sc.pl.violin(adata_s,genes,groupby=col)
sc.pl.violin(adata_s,genes,groupby=col,rotation=rotation,jitter=jitter)
if save:
genes = '_'.join(genes)
plt.savefig(os.path.join(self.dir,'{}_{}_heterogeneity_{}_{}_{}.{}'.format(key,cluster,col,genes,style,format)),bbox_inches='tight')
Expand Down Expand Up @@ -841,12 +876,13 @@ def plot_heterogeneity(self,key,cluster,col,style,save=False,format='pdf',genes=
fig.write_image(os.path.join(self.dir,'{}_{}_heterogeneity_{}_{}.{}'.format(key,cluster,col,style,format)))


def plot_circular_barplot(self,key,col,save=False,format='pdf'):
def plot_circular_barplot(self,key,col,save=True,format='pdf'):
# col can be 'raw' or 'pruned'
obs = copy.deepcopy(self.adata.obs)
reference = key
obs['value'] = np.full(shape=obs.shape[0], fill_value=1)
obs = obs.loc[:, [reference, col, 'value']]
print(obs)
obs4plot = obs.groupby(by=[reference, col])['value'].sum().reset_index()
print(obs.groupby(by=[reference, col])['value'])
print(obs.groupby(by=[reference, col])['value'].sum())
Expand Down
108 changes: 108 additions & 0 deletions test_sctriangulate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@

'''This script is to test sctriangulate program,
Using pbmc3k dataset'''

# End-to-end smoke test: runs the ScTriangulate pipeline once on a single dataset
# and then calls each plotting / export method once. The script contains no
# assertions, so a step "passes" simply by not raising.

import os
import sys
import scanpy as sc
import pandas as pd
import numpy as np
# NOTE(review): hard-coded, machine-specific working directory. Every relative
# path below (the .h5ad input, './output', 'output/after_prune_rank.p') is
# resolved against it, so the script only runs as-is on the author's machine.
os.chdir('/Users/ligk2e/Desktop/scTriangulate')
# Make the working directory importable so the star import below can resolve
# `sctriangulate` -- presumably from a local checkout rather than an installed
# package; TODO confirm.
sys.path.append('.')
# Star import: `ScTriangulate` used below is expected to come from here.
from sctriangulate import *

# load the data
adata = sc.read('pbmc3k_azimuth_umap.h5ad')

# instantiation
# `query` lists three names that are later used as keys ('leiden1', 'leiden3');
# presumably they are annotation columns in adata.obs -- verify against the class.
sctri = ScTriangulate(dir='./output',adata=adata,query=['leiden1','leiden2','leiden3'])

# main program
# Order matters: metrics -> shapley -> rank pruning; each call works on the
# state left on `sctri` by the previous one.
sctri.compute_metrics(parallel=True,scale_sccaf=False)
sctri.compute_shapley(parallel=True)
sctri.pruning(method='rank',scale_sccaf=False)

# clean step
# Mark clusters whose 'win_fraction' is below 0.25 as invalid, then prune again
# with the 'reassign' method using 'leiden1' as reference.
sctri.add_to_invalid_by_win_fraction(percent=0.25)
sctri.pruning(method='reassign',abs_thresh=10,remove1=True,reference='leiden1')

# IO
# Round-trip through disk: `sctri` is rebound to the reloaded object, so every
# later step runs on the deserialized copy.
# NOTE(review): the read path assumes serialize() writes into the instance's
# dir ('./output') -- TODO confirm.
sctri.serialize(name='after_prune_rank.p')
sctri = ScTriangulate.deserialize('output/after_prune_rank.p')


# generate viewer
sctri.viewer_cluster_feature_figure()
sctri.viewer_cluster_feature_html()
# Same invalidation + reassign pruning as the clean step, this time with
# reference='leiden3', ahead of the 'leiden3' heterogeneity viewer.
sctri.add_to_invalid_by_win_fraction(percent=0.25)
sctri.pruning(method='reassign',abs_thresh=10,remove1=True,reference='leiden3')
sctri.viewer_heterogeneity_figure(keys=['leiden3'])
sctri.viewer_heterogeneity_html(key='leiden3')

'''test plot_heterogeneity function'''
# In every call below the positional arguments are (key, cluster, style); `col`
# is left at its default. `subset` keeps only the cells whose `col` value is in
# the given list before plotting.
# umap
sctri.plot_heterogeneity('leiden1','0','umap',subset=['leiden1@0','leiden3@10'])

# heatmap
sctri.plot_heterogeneity('leiden1','0','heatmap',subset=['leiden1@0','leiden3@10'])

# heatmap_custom_gene
# Keys are sub-population labels, values are the genes to show for each; the
# dict is handed through to the heatmap call for this style.
marker_gene_dict = {
'leiden1@0':['MAPK14','RASSF1','PARP10'],
'leiden3@10':['ANXA1','CD52','CMTM7']
}
sctri.plot_heterogeneity('leiden1','0','heatmap_custom_gene',subset=['leiden1@0','leiden3@10'],marker_gene_dict=marker_gene_dict)

# violin
sctri.plot_heterogeneity('leiden1','0','violin',subset=['leiden1@0','leiden3@10'],genes=['MAPK14','ANXA1'])

# sankey
sctri.plot_heterogeneity('leiden1','0','sankey')

# cellxgene
sctri.plot_heterogeneity('leiden1','0','cellxgene')

# build
sctri.plot_heterogeneity('leiden1','0','build')

'''test other plotting function'''
sctri.plot_circular_barplot('leiden1','pruned')
# Extra keyword arguments (here annot=True) are forwarded to the heatmap call.
sctri.plot_confusion('confusion_reassign','leiden1',annot=True)
sctri.plot_cluster_feature('leiden1','0','enrichment')
sctri.plot_cluster_feature('leiden1','0','marker_genes')
sctri.plot_cluster_feature('leiden1','0','exclusive_genes')
sctri.plot_cluster_feature('leiden1','0','location')
# 'confidence' is an obs column written during pruning; umap_dot_size overrides
# the default dot size that is otherwise derived from the number of cells.
sctri.plot_umap('confidence','continuous',umap_dot_size=10)
# Both calls return their computed data (stats table / cluster->count mapping)
# in addition to drawing the figure.
df = sctri.plot_winners_statistics('raw')
bucket = sctri.plot_clusterability('pruned')

# output useful intermediate result
# NOTE: `df` is rebound here; the winners-statistics table from above is dropped.
# The argument is presumably a cell barcode present in this dataset -- TODO confirm.
df = sctri.get_metrics_and_shapley('TTTCTACTGAGGCA-1')
sctri.obs_to_df()
sctri.var_to_df()
sctri.gene_to_df('exclusive_genes','leiden1')
sctri.confusion_to_df('confusion_sccaf','leiden1')
# Second positional argument is `save`: True appends the rendered tree to a text
# file in the output directory instead of printing it to stdout.
sctri.display_hierarchy('pruned',True)























Loading

0 comments on commit 65530fd

Please sign in to comment.