-
Notifications
You must be signed in to change notification settings - Fork 3
/
accuracy.py
138 lines (93 loc) · 4.13 KB
/
accuracy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""
Calculate the accuracy of sampled subgraph pattern frequencies.
A command-line tool that takes the discovered subgraph patterns
from exact counting and reservoir sampling runs on a graph and
compares their frequencies.
The resulting Precision, Recall and Average Relative Error values
reflect how accurately the reservoir sampling scheme maintains
the distribution of different subgraph patterns when compared
to the exact counting scheme.
"""
import csv
import math
import pprint
from collections import Counter
from argparse import ArgumentParser, FileType
def parse_patterns_file(patterns_file, runs):
    """Read per-run pattern counts from a space-delimited patterns file.

    The file must start with a header row naming a ``canonical_label``
    column and the columns ``count_1`` .. ``count_<runs>``.

    Args:
        patterns_file: an open text file (usable as a context manager);
            it is closed once it has been read.
        runs: number of ``count_N`` columns to read.

    Returns:
        A list of ``runs`` Counters; the i-th one maps each canonical
        label to the integer found in column ``count_<i+1>``.
    """
    count_columns = ['count_%d' % run for run in range(1, runs + 1)]
    counters = [Counter() for _ in count_columns]
    with patterns_file as handle:
        for record in csv.DictReader(handle, delimiter=' '):
            for counter, column in zip(counters, count_columns):
                # Convert the count before looking up the label.
                count = int(record[column])
                counter[record['canonical_label']] = count
    return counters
def pattern_frequencies(pattern_counts):
    """Calculate the relative frequency of each pattern.

    Args:
        pattern_counts: mapping of pattern label to its count.

    Returns:
        A dict mapping each label to ``count / total``, where ``total``
        is the sum of all counts. When the counts sum to zero, every
        pattern gets a frequency of 0.0.
    """
    total = float(sum(pattern_counts.values()))
    if total == 0:
        # A run with no occurrences at all would otherwise divide by zero.
        return {ptrn: 0.0 for ptrn in pattern_counts}
    return {ptrn: count / total for ptrn, count in pattern_counts.items()}
def threshold_frequencies(pattern_freqs, tau):
    """Keep only the patterns whose frequency is at least tau.

    Returns a new dict; the input mapping is not modified.
    """
    kept = {}
    for label, frequency in pattern_freqs.items():
        if frequency >= tau:
            kept[label] = frequency
    return kept
def precision(exact_patterns, sampled_patterns):
    """Fraction of the sampled patterns that also occur in the exact set.

    Both arguments may be any iterable of hashable pattern labels; a dict
    contributes its keys. When nothing was sampled the result is 1 if the
    exact set is empty as well, otherwise 0.
    """
    truth = set(exact_patterns)
    reported = set(sampled_patterns)
    if not reported:
        return int(not truth)
    hits = truth.intersection(reported)
    return len(hits) / float(len(reported))
def recall(exact_patterns, sampled_patterns):
    """Fraction of the exact patterns that were also found by sampling.

    Both arguments may be any iterable of hashable pattern labels; a dict
    contributes its keys. When the exact set is empty the result is 1 if
    the sampled set is empty as well, otherwise 0.
    """
    truth = set(exact_patterns)
    reported = set(sampled_patterns)
    if not truth:
        return int(not reported)
    hits = truth.intersection(reported)
    return len(hits) / float(len(truth))
def avg_relative_error(exact_patterns, sampled_patterns, T_k):
    """Sum the relative frequency errors over the exact patterns, over T_k.

    Args:
        exact_patterns: mapping of pattern label to its exact frequency.
        sampled_patterns: mapping of pattern label to its sampled
            frequency; a pattern absent from it counts as frequency 0.
        T_k: divisor applied to the accumulated error.

    Returns:
        ``sum(|q_i - p_i| / p_i) / T_k`` with p_i the exact and q_i the
        sampled frequency of each exact pattern.
    """
    total_error = 0
    # Accumulate term by term in the mapping's iteration order.
    for label, exact_freq in exact_patterns.items():
        sampled_freq = sampled_patterns.get(label, 0)
        total_error += abs(sampled_freq - exact_freq) / exact_freq
    return total_error / float(T_k)
def main():
    """Command-line entry point: compare sampled and exact pattern frequencies.

    Reads the exact-counting patterns file (one run) and the reservoir
    sampling patterns file (``--runs`` runs), converts the counts to
    relative frequencies and, for each threshold in a fixed list scaled by
    ``--tau``, prints the average relative error, precision and recall
    averaged over the sampling runs.

    Exits through ``parser.error`` (usage message, exit status 2) when
    ``--runs`` or ``T_k`` is smaller than 1, since both are used as
    divisors.
    """
    parser = ArgumentParser(description="Calculate accuracy of FSM sampling runs.")
    parser.add_argument('exact_patterns_file',
                        type=FileType('r'),
                        help="path to the file that contains exact counting patterns")
    parser.add_argument('sampled_patterns_file',
                        type=FileType('r'),
                        help="path to the file that contains reservoir sampling patterns")
    parser.add_argument('T_k',
                        type=int,
                        help="number of unique subgraph patterns")
    parser.add_argument('-t', '--tau',
                        type=float,
                        default=0.001,
                        help="coefficient to multiply frequency thresholds (default 0.001)")
    parser.add_argument('-r', '--runs',
                        type=int,
                        default=5,
                        help="number of runs provided for reservoir sampling (default 5)")
    args = vars(parser.parse_args())

    T_k = args['T_k']
    runs = args['runs']
    tau_coefficient = args['tau']

    # Both values are divisors below; reject them here with a usage error
    # instead of failing later with a ZeroDivisionError traceback (or, for
    # a negative T_k, silently printing a negative error).
    if runs < 1:
        parser.error("--runs must be at least 1")
    if T_k < 1:
        parser.error("T_k must be at least 1")

    # The exact-counting file holds a single run (column count_1).
    exact_pattern_counts = parse_patterns_file(args['exact_patterns_file'], 1)[0]
    exact_pattern_freqs = pattern_frequencies(exact_pattern_counts)

    sampled_pattern_counts = parse_patterns_file(args['sampled_patterns_file'], runs)
    sampled_pattern_freqs = [pattern_frequencies(c) for c in sampled_pattern_counts]

    run_count = float(runs)
    for threshold in [0.001, 0.01, 0.1, 0.2, 1, 2, 10]:
        tau = threshold * tau_coefficient
        print("\nThreshold", tau)

        exact_patterns = threshold_frequencies(exact_pattern_freqs, tau)

        # Accumulate each metric over the sampling runs, then average.
        are = 0
        prec = 0
        rec = 0
        for pattern_freqs in sampled_pattern_freqs:
            sampled_patterns = threshold_frequencies(pattern_freqs, tau)
            are += avg_relative_error(exact_patterns, sampled_patterns, T_k)
            prec += precision(exact_patterns, sampled_patterns)
            rec += recall(exact_patterns, sampled_patterns)

        print("ARE :", are / run_count)
        print("precision:", prec / run_count)
        print("recall :", rec / run_count)
# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()