forked from toolleeo/special-section-processing
-
Notifications
You must be signed in to change notification settings - Fork 0
/
clusterAfterMatrix.py
101 lines (68 loc) · 2.67 KB
/
clusterAfterMatrix.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import argparse
import pandas as pd
import itertools
import matplotlib.pyplot as plt
import networkx as nx
from cdlib import algorithms, viz
from scipy.stats import norm
import statistics
def edges_calc(file, threshold=23):
    """Compute pairwise track intersections from a tab-separated matrix.

    The file is read with pandas. Rows whose ``total`` column equals 1 are
    discarded (no connections between tracks). The track columns are taken
    to be every column except the first two and the last one.
    NOTE(review): presumably the first two columns are row identifiers and
    the last one is ``total`` -- confirm against the matrix producer.

    For every unordered pair of tracks, the intersection is the number of
    remaining rows in which both tracks have a value of at least 2. A pair
    becomes a graph edge when its intersection is strictly greater than
    ``threshold``.

    Args:
        file: Path or file-like object accepted by ``pandas.read_csv``.
        threshold: Minimum intersection (exclusive) for a pair to become an
            edge. The default, 23, is the previously hard-coded
            ``13 + 1 * 10`` (mean + 1 * sd).

    Returns:
        A tuple ``(name_columns, graph_edges, edges_weights,
        intersection_values)``: the track names, the ``(track1, track2)``
        edges that passed the threshold, the intersection of each of those
        edges (same order), and the intersection of every pair (in
        ``itertools.combinations`` order).
    """
    df = pd.read_csv(file, sep='\t')
    # Removing rows with no connections between tracks.
    df = df[df['total'] != 1]
    # Keep only the track columns.
    name_columns = df.columns.tolist()[2:-1]

    # One boolean mask per track, computed once instead of once per pair.
    meets_minimum = df[name_columns] >= 2

    graph_edges = []
    edges_weights = []
    intersection_values = []
    for track1, track2 in itertools.combinations(name_columns, 2):
        both = meets_minimum[track1] & meets_minimum[track2]
        # Cast so callers keep receiving plain Python ints.
        intersection = int(both.sum())
        intersection_values.append(intersection)
        if intersection > threshold:
            graph_edges.append((track1, track2))
            edges_weights.append(intersection)
    return name_columns, graph_edges, edges_weights, intersection_values
def nodes_calc(name_columns):
    """Return the graph nodes: a new list with one entry per track name.

    Args:
        name_columns: Iterable of track names.

    Returns:
        A fresh list holding the same names in the same order, so the
        caller's sequence is never aliased.
    """
    return list(name_columns)
def create_graph(nodes, edges, weights):
    """Cluster the track graph with Leiden and save the plot as SVG.

    Builds an undirected networkx graph from ``nodes`` and ``edges``, runs
    cdlib's ``leiden`` on it with ``weights``, lays it out with
    ``spring_layout`` and writes the cluster plot to ``clusterOnMatrix.svg``.

    Args:
        nodes: Node identifiers to add to the graph.
        edges: ``(node, node)`` pairs to add as edges.
        weights: Edge weights handed to ``algorithms.leiden``.
            NOTE(review): passed as a separate list rather than as edge
            attributes, so it presumably has to line up with the graph's
            edge order -- confirm against the cdlib documentation.
    """
    plt.figure(figsize=(15, 10))

    graph = nx.Graph()
    graph.add_nodes_from(nodes)
    graph.add_edges_from(edges)

    communities = algorithms.leiden(graph, weights=weights)
    layout = nx.spring_layout(graph, k=1)
    viz.plot_network_clusters(
        graph,
        communities,
        layout,
        plot_labels=True,
        node_size=900,
    )

    plt.savefig('clusterOnMatrix.svg', format="svg")
    # Clear the current figure so later plots start from a clean canvas.
    plt.clf()
def create_intersectionHistogram(intersection_values):
    """Plot a histogram of the intersections with a fitted normal curve.

    Prints two (mean, sd) pairs: first the values from the ``statistics``
    module (``stdev`` is the sample standard deviation), then the
    parameters returned by ``scipy.stats.norm.fit``. The histogram and the
    normal PDF built from the fitted parameters are saved to
    ``after_matrix_intersection_histogram.png``.

    Args:
        intersection_values: Iterable of intersection counts. It does not
            need to be sorted and is not modified.

    Raises:
        statistics.StatisticsError: If fewer than two values are given
            (raised by ``statistics.stdev``).
    """
    # The PDF is drawn as a line through the points in order, so the x
    # values must be ascending; sort a copy instead of relying on the caller.
    values = sorted(intersection_values)

    sample_mean = statistics.mean(values)
    sample_sd = statistics.stdev(values)
    print(sample_mean, sample_sd)

    fitted_mean, fitted_sd = norm.fit(values)
    print(fitted_mean, fitted_sd)

    plt.hist(values, density=True, bins=50, alpha=0.5)
    plt.plot(values, norm.pdf(values, fitted_mean, fitted_sd))
    plt.xlabel('intersections')
    plt.savefig('after_matrix_intersection_histogram.png')
    # Clear the current figure so later plots start from a clean canvas.
    plt.clf()
def main():
    """Command-line entry point.

    Reads the matrix path from the single positional ``file`` argument,
    computes the track intersections, writes the cluster plot and then the
    intersection histogram.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('file', type=str, help='matrix')
    matrix_path = cli.parse_args().file

    tracks, edges, weights, counts = edges_calc(matrix_path)
    create_graph(nodes_calc(tracks), edges, weights)

    # The histogram is drawn from the counts in ascending order.
    counts.sort()
    create_intersectionHistogram(counts)
# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()