-
Notifications
You must be signed in to change notification settings - Fork 0
/
5.calculate_mrr.py
81 lines (50 loc) · 2.52 KB
/
5.calculate_mrr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import json
import os
import sqlite3

import kenlm

from utils import *
# Module-level SQLite handle. evaluate_kenlm_model uses the global `cursor`
# (passed to get_content_from_db) to look up each test graph by its key.
_DATABASE_PATH = "database.db"
conn = sqlite3.connect(_DATABASE_PATH)
cursor = conn.cursor()
def evaluate_kenlm_model(model):
    """Evaluate ``model`` on the test graphs and append per-graph MRR scores.

    Each non-blank line of ``test_hashes.txt`` is a graph key. The graph's JSON
    is fetched with ``get_content_from_db(key, cursor)`` using the module-level
    SQLite ``cursor``. For every node of a graph that has connections, the
    node's true object type is ranked by ``get_rank`` against the paths ending
    at that node. The graph's Mean Reciprocal Rank is the sum of ``1/rank``
    over its nodes divided by the node count; nodes with rank -1 count as 0.

    Output files (all opened in append mode):

    * ``output/<key>.txt``: one line per node, ``<node> <#paths> <type> <rank>``.
      The ``output`` directory is created if missing.
    * ``mrr.txt``: one line per graph, ``<key> <mrr>``.
    * ``exception.txt``: ``<key>: No connections found`` for graphs without
      connections, or whose JSON lacks ``connections``/``all_objects``.

    Args:
        model: KenLM model, passed through unchanged to ``get_rank``.
    """
    # Read the key list up front so the input file is closed before the
    # long-running evaluation starts.
    with open("test_hashes.txt", "r") as hashes_file:
        raw_lines = hashes_file.readlines()

    for processed, raw_line in enumerate(raw_lines):
        print("Processed :", processed)
        graph_key = raw_line.strip()
        if not graph_key:
            # Skip blank lines (e.g. an extra empty line at the end of the file)
            # instead of querying the database with an empty key.
            continue

        content = get_content_from_db(graph_key, cursor)
        connections, all_objects = _read_graph(content)

        if connections:
            mrr = _graph_mrr(graph_key, connections, all_objects, model)
            with open("mrr.txt", "a") as mrr_file:
                mrr_file.write(graph_key + " " + str(mrr) + "\n")
        else:
            with open("exception.txt", "a") as exception_file:
                exception_file.write(graph_key + ": No connections found\n")


def _read_graph(content):
    """Parse graph JSON into ``(connections, all_objects)``.

    Returns two empty lists when the document is not a JSON object or lacks
    either key, so the caller treats it as a graph with no connections.
    """
    data = json.loads(content)
    try:
        return data["connections"], data["all_objects"]
    except (KeyError, TypeError):
        # KeyError: a key is missing. TypeError: the top-level JSON value is
        # not an object (e.g. a list or null).
        return [], []


def _graph_mrr(graph_key, connections, all_objects, model):
    """Return the MRR over all nodes of one graph, logging each node's rank.

    Per-node results are appended to ``output/<graph_key>.txt``.
    ``connections`` must be non-empty, so the node set is never empty.
    """
    object_dict = create_a_dictionary_of_object_id_to_type(all_objects)
    # The first element of each patchline endpoint is used as the object id
    # (it is the key into object_dict below).
    sources = [c["patchline"]["source"][0] for c in connections]
    destinations = [c["patchline"]["destination"][0] for c in connections]
    nodes = set(sources + destinations)
    reversed_graph = create_reverse_directed_graph(connections, all_objects)

    reciprocal_rank_sum = 0.0
    os.makedirs("output", exist_ok=True)
    with open(os.path.join("output", graph_key + ".txt"), "a") as node_log:
        for node in nodes:
            # Restart the search from scratch for every node: fresh visited
            # marks, path buffer and result list.
            paths_ending_here = []
            visited = dict.fromkeys(nodes, False)
            # NOTE(review): presumably enumerates the paths of length three
            # ending at `node` (judging by the helper's name). TODO confirm.
            three_length_dfs(node, reversed_graph, visited, [], paths_ending_here)
            true_type = object_dict[node]
            # NOTE(review): assumes get_rank returns a positive 1-based rank,
            # or -1 when the true type is not ranked. The meaning of its last
            # argument (0) is not visible here. TODO confirm.
            rank = get_rank(paths_ending_here, model, object_dict, true_type, 0)
            node_log.write(
                node + " " + str(len(paths_ending_here)) + " "
                + true_type + " " + str(rank) + "\n"
            )
            if rank != -1:
                reciprocal_rank_sum += 1.0 / rank
    return reciprocal_rank_sum / len(nodes)
if __name__ == "__main__":
    # Load the model and run the (long) evaluation only when executed as a
    # script, not as a side effect of importing this module.
    model = kenlm.Model('trained_models/kenlm_3_paths_all_not_padded.arpa')
    try:
        evaluate_kenlm_model(model)
    finally:
        # Release the module-level SQLite connection even if evaluation fails.
        conn.close()