llm_fact_check.py
import ast
import glob
import json
import os
import pickle
import re
from argparse import ArgumentParser
from collections import defaultdict
from functools import partial
from multiprocessing import Pool, Manager

import numpy as np
import pandas as pd
from openai import OpenAI
from retrying import retry
from sklearn.metrics import classification_report

from kg import KG
from mine_llm_filtered_relation import fuzzy_matchEntities, validateRelation, paths_to_str2
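
# Fact-check FactKG claims with Meta-Llama-3-8B-Instruct served through an
# OpenAI-compatible vLLM endpoint. By default the prompt includes evidence
# paths retrieved from the DBpedia knowledge graph; with --llm_knowledge the
# model judges each claim from its own knowledge alone.
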
@retry(stop_max_attempt_number=10, wait_fixed=0)
def call_llm(row, evidences):
    """Ask the LLM to verify one claim; retry until the first line contains True/False."""
    claim = row.Sentence
    if not args.llm_knowledge:
        instruction_head = '''
You are an intelligent fact-checker. You are given a single claim and supporting evidence for the entities present in the claim, extracted from a knowledge graph.
Your task is to decide whether all the facts in the given claim are supported by the given evidence.
'''
    else:
        instruction_head = '''
You are an intelligent fact-checker trained on Wikipedia. You are given a single claim and your task is to decide whether all the facts in the given claim are supported, using your own knowledge.
'''
    if not args.llm_knowledge:
        content = f'''
## TASK:
Now let’s verify the Claim based on the evidences.
Claim: {claim}
Evidences:
{evidences}
'''
    else:
        content = f'''
## TASK:
Now let’s verify the Claim.
Claim: {claim}
'''
    content += '''
#Answer Template:
"True/False (single word answer),
One-sentence evidence."
'''
    message = [
        {"role": "system", "content": instruction_head + '''
Choose one of {True, False}, and output the one-sentence explanation for the choice.
'''},
        {"role": "user", "content": content},
    ]
    # print(message[1]['content'])
    # breakpoint()
    chat_response = client.chat.completions.create(
        model="meta-llama/Meta-Llama-3-8B-Instruct", messages=message
    )
    text = chat_response.choices[0].message.content
    first_line = text.split("\n")[0].strip()
    # The first line must contain 'True' or 'False'; otherwise raise so @retry re-asks.
    if not any(i in first_line for i in ["True", "False"]):
        print("retry")
        raise IOError("True/False not in the first line")
    output_decision = "True" in first_line
    return text, output_decision

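# Command-line configuration: data locations, dataset split, worker count, and
# whether to skip knowledge-graph evidence in the prompt.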
parser = ArgumentParser()
parser.add_argument("--data_path", default="/fp/projects01/ec30/factkg/full/")
parser.add_argument("--dbpedia_path", default="/fp/projects01/ec30/factkg/dbpedia/dbpedia_2015_undirected_light.pickle")
parser.add_argument("--evidence_path", default="./llm_v1_jsons", help="Path to the evidence JSONs predicted by the LLM.")
parser.add_argument("--set", choices=["test", "train", "val"], default="train")
parser.add_argument("--num_proc", type=int, default=10)
parser.add_argument("--llm_knowledge", action="store_true", help="If set, fact-check with the claim only (no KG evidence in the prompt).")
parser.add_argument("--vllm_url", default="http://g002:8000", help="URL of the vLLM server, e.g., http://g002:8000")
args = parser.parse_args()
print(args)

# The vLLM server exposes an OpenAI-compatible API, so the API key is a dummy value.
client = OpenAI(
    api_key="EMPTY",
    base_url=args.vllm_url + "/v1",
)

kg = KG(pickle.load(open(args.dbpedia_path, 'rb')))
df = pd.read_csv(args.data_path + f'{args.set}.csv')
dfx = df
print("Total rows to process", len(dfx))
if not args.llm_knowledge:
    all_evidence = {}
    for file in glob.glob(f'{args.evidence_path}/llm_{args.set}/**.json'):
        idx = int(file.split('/')[-1].split('.')[0])
        if idx in dfx.index:
            with open(file) as f:
                all_evidence[idx] = json.load(f)
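
# Predictions are collected in a Manager dict so that worker processes can
# report results back to the parent process.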
manager = Manager()
real_predicted = manager.dict()

def process_row(index, row, _shared_dict):
    if not args.llm_knowledge:
        data = all_evidence[index]
        true_entities = ast.literal_eval(row["Entity_set"])
        predicted_entities = [k for k in data.keys() if data[k] != []]
        resolved_entities = fuzzy_matchEntities(true_entities, predicted_entities, data)
        resolved_entities_relation = validateRelation(resolved_entities, row, kg)
        kg_results = kg.search(sorted(resolved_entities_relation.keys()), resolved_entities_relation)
        supporting_evidences = "\n".join(
            [path for typ in ["connected", "walkable"] for path in paths_to_str2(kg_results[typ])]
        )
        text, output_decision = call_llm(row, supporting_evidences)
    else:
        text, output_decision = call_llm(row, "")
    _shared_dict[index] = [output_decision, text.replace("\n", "|")]
    print(index, output_decision)

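# Verify every claim in parallel; each worker writes its decision and raw
# response into the shared dict, keyed by dataframe index.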
partial_process_row = partial(process_row, _shared_dict=real_predicted)
with Pool(processes=args.num_proc) as pool:
    pool.starmap(partial_process_row, dfx.iterrows())

for key, value in real_predicted.items():
    dfx.at[key, 'Predicted'] = value[0]
    dfx.at[key, 'Response'] = value[1]
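
# Overall classification report across the whole split.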
print(classification_report(dfx["Label"].tolist(), dfx["Predicted"].tolist()))
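
# Break the results down by reasoning type parsed from the claim metadata;
# claims tagged 'negation' are reported separately and excluded from the
# other buckets.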
dfx['Metadata'] = [ast.literal_eval(e) for e in dfx.Metatada]
interesting = ['num1', 'multi claim', 'existence', 'multi hop']
interesting_list = defaultdict(list)
for index, row in dfx.iterrows():
    if "negation" in row['Metadata']:
        interesting_list['negation'].append([row['Label'], row['Predicted']])
        continue
    for each in interesting:
        if each in row['Metadata']:
            interesting_list[each].append([row['Label'], row['Predicted']])

for each in interesting_list.keys():
    print(f"\nClassification report for {each}")
    print(classification_report([i[0] for i in interesting_list[each]], [i[1] for i in interesting_list[each]]))
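
# Save predictions, raw responses, and gold labels for later analysis.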
filtered_df = dfx[['Predicted', 'Response', 'Label']]
if_llm_knowledge = "_llm_knowledge" if args.llm_knowledge else ""
filtered_df.to_csv(f"llm_prompt_check_{args.set}{if_llm_knowledge}.csv", index=True)
print(f"saved to llm_prompt_check_{args.set}{if_llm_knowledge}.csv")
# Usage: python llm_fact_check.py --set test --num_proc 50 [--llm_knowledge]
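# Note: this script assumes an OpenAI-compatible endpoint (for example, a
# separately started vLLM server hosting Meta-Llama-3-8B-Instruct) is already
# reachable at the URL passed via --vllm_url.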