From b989c0dd35877119becfd9120977398aa275035f Mon Sep 17 00:00:00 2001 From: Nick Gomez Date: Tue, 9 May 2023 09:49:09 -0700 Subject: [PATCH] Fixed single connected component error and added functionality to use a list of thresholds instead of single value --- Topyfic/utils.py | 134 +++++++++++++++++++++++------------------------ 1 file changed, 67 insertions(+), 67 deletions(-) diff --git a/Topyfic/utils.py b/Topyfic/utils.py index c5148d7..81db2c9 100644 --- a/Topyfic/utils.py +++ b/Topyfic/utils.py @@ -15,6 +15,7 @@ import scanpy.external as sce import networkx as nx import math +from matplotlib import gridspec from Topyfic.train import * from Topyfic.topic import * @@ -527,7 +528,7 @@ def read_analysis(file): def compare_topModels(topModels, output_type='graph', - threshold=0.8, + threshold=[0.8], topModels_color=None, topModels_label=None, save=False, @@ -537,13 +538,12 @@ def compare_topModels(topModels, file_name="compare_topics"): """ compare several topModels - :param topModels: list of topModel class you want to compare to each other :type topModels: list of TopModel class :param output_type: indicate the type of output you want. graph: plot as a graph, heatmap: plot as a heatmap, table: table contains correlation :type output_type: str :param threshold: only apply when you choose circular which only show correlation above that - :type threshold: float + :type threshold: list of float :param topModels_color: dictionary of colors mapping each topics to each color (default: blue) :type topModels_color: dict :param topModels_label: dictionary of label mapping each topics to each label @@ -558,7 +558,6 @@ def compare_topModels(topModels, :type plot_format: str :param file_name: name and path of the plot use for save (default: piechart_topicAvgCell) :type file_name: str - :return: table contains correlation between topics only when table is choose and save is False :rtype: pandas dataframe """ @@ -609,70 +608,72 @@ def compare_topModels(topModels, if output_type == 'graph': np.fill_diagonal(corrs.values, 0) - corrs[corrs < threshold] = np.nan - res = corrs.stack() - res = pd.DataFrame(res) - res.reset_index(inplace=True) - res.columns = ['source', 'dest', 'weight'] - res['weight'] = res['weight'].astype(float).round(decimals=2) - res['source_label'] = res['source'] - res['dest_label'] = res['dest'] - res['source_color'] = res['source'] - res['dest_color'] = res['dest'] - - if topModels_label is not None: - res['source_label'].replace(topModels_label, inplace=True) - res['dest_label'].replace(topModels_label, inplace=True) - if topModels_color is None: - res['source_color'] = "blue" - res['dest_color'] = "blue" - else: - res['source_color'].replace(topModels_color, inplace=True) - res['dest_color'].replace(topModels_color, inplace=True) - - G = nx.Graph() - for i in range(res.shape[0]): - G.add_node(res.source_label[i], color=res.source_color[i]) - G.add_node(res.dest_label[i], color=res.dest_color[i]) - G.add_edge(res.source_label[i], res.dest_label[i], weight=res.weight[i]) - - connected_components = sorted(nx.connected_components(G), key=len, reverse=True) - + nrows = math.ceil(math.sqrt(len(threshold))) + ncols = math.ceil(len(threshold) / nrows) + outer_grid = gridspec.GridSpec(nrows, ncols, wspace=0.05, hspace=0.12) if figsize is None: - figsize = (len(connected_components) * 2, len(connected_components) * 2) + figsize=(int(4 * nrows),int(4 * nrows)) - nrows = math.ceil(math.sqrt(len(connected_components))) - ncols = math.ceil(len(connected_components) / nrows) fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, facecolor='white') - - i = 0 - for connected_component in connected_components: - g_connected_component = G.subgraph(connected_component) - nodePos = nx.spring_layout(g_connected_component) - - edge_labels = nx.get_edge_attributes(g_connected_component, "weight") - - node_color = nx.get_node_attributes(g_connected_component, "color").values() - weights = nx.get_edge_attributes(g_connected_component, 'weight').values() - - nx.draw_networkx(g_connected_component, - pos=nodePos, - width=list(weights), - with_labels=True, - node_color=list(node_color), - font_size=8, - node_size=500, - ax=axs[int(i / ncols), i % ncols]) - - nx.draw_networkx_edge_labels(g_connected_component, - nodePos, - edge_labels=edge_labels, - font_size=7, - ax=axs[int(i / ncols), i % ncols]) - - i += 1 - - [axi.axis('off') for axi in axs.ravel()] + plt.axis('off') + for i, thresholds in enumerate(threshold): + corrs.edit = corrs.copy() + corrs.edit[corrs.edit < thresholds] = np.nan + res = corrs.edit.stack() + res = pd.DataFrame(res) + res.reset_index(inplace=True) + res.columns = ['source', 'dest', 'weight'] + res['weight'] = res['weight'].astype(float).round(decimals=2) + res['source_label'] = res['source'] + res['dest_label'] = res['dest'] + res['source_color'] = res['source'] + res['dest_color'] = res['dest'] + + if topModels_label is not None: + res['source_label'].replace(topModels_label, inplace=True) + res['dest_label'].replace(topModels_label, inplace=True) + if topModels_color is None: + res['source_color'] = "blue" + res['dest_color'] = "blue" + else: + res['source_color'].replace(topModels_color, inplace=True) + res['dest_color'].replace(topModels_color, inplace=True) + + G = nx.Graph() + for j in range(res.shape[0]): + G.add_node(res.source_label[j], color=res.source_color[j]) + G.add_node(res.dest_label[j], color=res.dest_color[j]) + G.add_edge(res.source_label[j], res.dest_label[j], weight=res.weight[j]) + + connected_components = sorted(nx.connected_components(G), key=len, reverse=True) + if len(connected_components)==1: + vis_ax =fig.add_subplot(outer_grid[int(i / ncols), i % ncols]) + vis_ax.set_title('Threshold = ' + str(round(thresholds,2)),fontsize=14) + + + g_connected_component = G.subgraph(connected_components[0]) + nodePos = nx.spring_layout(g_connected_component,k=1.2) + edge_labels = nx.get_edge_attributes(g_connected_component, "weight") + node_color = nx.get_node_attributes(g_connected_component, "color").values() + weights = nx.get_edge_attributes(g_connected_component, 'weight').values() + + nx.draw_networkx(g_connected_component,pos=nodePos, width=list(weights),with_labels=True,node_color=list(node_color),font_size=12,node_size=500,ax = vis_ax) + + nx.draw_networkx_edge_labels(g_connected_component,nodePos,edge_labels=edge_labels,font_size=8,ax=vis_ax) + else: + inner_grid = gridspec.GridSpecFromSubplotSpec(len(connected_components), 1,subplot_spec=outer_grid[i], wspace=0.0, hspace=0.0) + for j, connected_component in enumerate(connected_components): + vis_split = fig.add_subplot(inner_grid[j]) + if j==0: + vis_split.set_title('Threshold = ' + str(round(thresholds,2)),fontsize=14) + g_connected_component = G.subgraph(connected_component) + nodePos = nx.spring_layout(g_connected_component) + edge_labels = nx.get_edge_attributes(g_connected_component, "weight") + node_color = nx.get_node_attributes(g_connected_component, "color").values() + weights = nx.get_edge_attributes(g_connected_component, 'weight').values() + nx.draw_networkx(g_connected_component,pos=nodePos,width=list(weights),with_labels=True,node_color=list(node_color),font_size=8,node_size=500,ax=vis_split) + nx.draw_networkx_edge_labels(g_connected_component,nodePos,edge_labels=edge_labels,font_size=7,ax=vis_split) + [axi.axis('off') for axi in axs.ravel()] plt.tight_layout() if save: plt.savefig(f"{file_name}.{plot_format}") @@ -681,5 +682,4 @@ def compare_topModels(topModels, else: plt.close() - return axs - + return