Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

single connected component and list of threshold enhancements #1

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 67 additions & 67 deletions Topyfic/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import scanpy.external as sce
import networkx as nx
import math
from matplotlib import gridspec

from Topyfic.train import *
from Topyfic.topic import *
Expand Down Expand Up @@ -527,7 +528,7 @@ def read_analysis(file):

def compare_topModels(topModels,
output_type='graph',
threshold=0.8,
threshold=[0.8],
topModels_color=None,
topModels_label=None,
save=False,
Expand All @@ -537,13 +538,12 @@ def compare_topModels(topModels,
file_name="compare_topics"):
"""
compare several topModels

:param topModels: list of topModel class you want to compare to each other
:type topModels: list of TopModel class
:param output_type: indicates the type of output you want. graph: plot as a graph, heatmap: plot as a heatmap, table: table containing the correlations
:type output_type: str
:param threshold: only applies when output_type is graph; only correlations above the threshold are shown
:type threshold: float
:type threshold: list of float
:param topModels_color: dictionary of colors mapping each topics to each color (default: blue)
:type topModels_color: dict
:param topModels_label: dictionary of label mapping each topics to each label
Expand All @@ -558,7 +558,6 @@ def compare_topModels(topModels,
:type plot_format: str
:param file_name: name and path of the file used to save the plot (default: compare_topics)
:type file_name: str

:return: table containing the correlations between topics, only when table is chosen and save is False
:rtype: pandas dataframe
"""
Expand Down Expand Up @@ -609,70 +608,72 @@ def compare_topModels(topModels,

if output_type == 'graph':
np.fill_diagonal(corrs.values, 0)
corrs[corrs < threshold] = np.nan
res = corrs.stack()
res = pd.DataFrame(res)
res.reset_index(inplace=True)
res.columns = ['source', 'dest', 'weight']
res['weight'] = res['weight'].astype(float).round(decimals=2)
res['source_label'] = res['source']
res['dest_label'] = res['dest']
res['source_color'] = res['source']
res['dest_color'] = res['dest']

if topModels_label is not None:
res['source_label'].replace(topModels_label, inplace=True)
res['dest_label'].replace(topModels_label, inplace=True)
if topModels_color is None:
res['source_color'] = "blue"
res['dest_color'] = "blue"
else:
res['source_color'].replace(topModels_color, inplace=True)
res['dest_color'].replace(topModels_color, inplace=True)

G = nx.Graph()
for i in range(res.shape[0]):
G.add_node(res.source_label[i], color=res.source_color[i])
G.add_node(res.dest_label[i], color=res.dest_color[i])
G.add_edge(res.source_label[i], res.dest_label[i], weight=res.weight[i])

connected_components = sorted(nx.connected_components(G), key=len, reverse=True)

nrows = math.ceil(math.sqrt(len(threshold)))
ncols = math.ceil(len(threshold) / nrows)
outer_grid = gridspec.GridSpec(nrows, ncols, wspace=0.05, hspace=0.12)
if figsize is None:
figsize = (len(connected_components) * 2, len(connected_components) * 2)
figsize=(int(4 * nrows),int(4 * nrows))

nrows = math.ceil(math.sqrt(len(connected_components)))
ncols = math.ceil(len(connected_components) / nrows)
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, facecolor='white')

i = 0
for connected_component in connected_components:
g_connected_component = G.subgraph(connected_component)
nodePos = nx.spring_layout(g_connected_component)

edge_labels = nx.get_edge_attributes(g_connected_component, "weight")

node_color = nx.get_node_attributes(g_connected_component, "color").values()
weights = nx.get_edge_attributes(g_connected_component, 'weight').values()

nx.draw_networkx(g_connected_component,
pos=nodePos,
width=list(weights),
with_labels=True,
node_color=list(node_color),
font_size=8,
node_size=500,
ax=axs[int(i / ncols), i % ncols])

nx.draw_networkx_edge_labels(g_connected_component,
nodePos,
edge_labels=edge_labels,
font_size=7,
ax=axs[int(i / ncols), i % ncols])

i += 1

[axi.axis('off') for axi in axs.ravel()]
plt.axis('off')
for i, thresholds in enumerate(threshold):
corrs.edit = corrs.copy()
corrs.edit[corrs.edit < thresholds] = np.nan
res = corrs.edit.stack()
res = pd.DataFrame(res)
res.reset_index(inplace=True)
res.columns = ['source', 'dest', 'weight']
res['weight'] = res['weight'].astype(float).round(decimals=2)
res['source_label'] = res['source']
res['dest_label'] = res['dest']
res['source_color'] = res['source']
res['dest_color'] = res['dest']

if topModels_label is not None:
res['source_label'].replace(topModels_label, inplace=True)
res['dest_label'].replace(topModels_label, inplace=True)
if topModels_color is None:
res['source_color'] = "blue"
res['dest_color'] = "blue"
else:
res['source_color'].replace(topModels_color, inplace=True)
res['dest_color'].replace(topModels_color, inplace=True)

G = nx.Graph()
for j in range(res.shape[0]):
G.add_node(res.source_label[j], color=res.source_color[j])
G.add_node(res.dest_label[j], color=res.dest_color[j])
G.add_edge(res.source_label[j], res.dest_label[j], weight=res.weight[j])

connected_components = sorted(nx.connected_components(G), key=len, reverse=True)
if len(connected_components)==1:
vis_ax =fig.add_subplot(outer_grid[int(i / ncols), i % ncols])
vis_ax.set_title('Threshold = ' + str(round(thresholds,2)),fontsize=14)


g_connected_component = G.subgraph(connected_components[0])
nodePos = nx.spring_layout(g_connected_component,k=1.2)
edge_labels = nx.get_edge_attributes(g_connected_component, "weight")
node_color = nx.get_node_attributes(g_connected_component, "color").values()
weights = nx.get_edge_attributes(g_connected_component, 'weight').values()

nx.draw_networkx(g_connected_component,pos=nodePos, width=list(weights),with_labels=True,node_color=list(node_color),font_size=12,node_size=500,ax = vis_ax)

nx.draw_networkx_edge_labels(g_connected_component,nodePos,edge_labels=edge_labels,font_size=8,ax=vis_ax)
else:
inner_grid = gridspec.GridSpecFromSubplotSpec(len(connected_components), 1,subplot_spec=outer_grid[i], wspace=0.0, hspace=0.0)
for j, connected_component in enumerate(connected_components):
vis_split = fig.add_subplot(inner_grid[j])
if j==0:
vis_split.set_title('Threshold = ' + str(round(thresholds,2)),fontsize=14)
g_connected_component = G.subgraph(connected_component)
nodePos = nx.spring_layout(g_connected_component)
edge_labels = nx.get_edge_attributes(g_connected_component, "weight")
node_color = nx.get_node_attributes(g_connected_component, "color").values()
weights = nx.get_edge_attributes(g_connected_component, 'weight').values()
nx.draw_networkx(g_connected_component,pos=nodePos,width=list(weights),with_labels=True,node_color=list(node_color),font_size=8,node_size=500,ax=vis_split)
nx.draw_networkx_edge_labels(g_connected_component,nodePos,edge_labels=edge_labels,font_size=7,ax=vis_split)
[axi.axis('off') for axi in axs.ravel()]
plt.tight_layout()
if save:
plt.savefig(f"{file_name}.{plot_format}")
Expand All @@ -681,5 +682,4 @@ def compare_topModels(topModels,
else:
plt.close()

return axs

return