-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluation.py
50 lines (42 loc) · 1.53 KB
/
evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from detect import Detector
import argparse
from utils import jload
from scipy.stats import pearsonr
from sklearn.metrics import roc_auc_score, roc_curve
from tqdm import tqdm
import pickle as pkl
def _parse_args(argv=None):
    """Build the command-line parser and parse *argv* (defaults to ``sys.argv[1:]``)."""
    parser = argparse.ArgumentParser(
        description="Evaluate a Detector against golden syntactic/lexical signals."
    )
    parser.add_argument("--model_path", required=True, help="path to saved model")
    parser.add_argument(
        "--device",
        default="cuda",
        type=str,
        help="device to run the detector on (e.g. cuda, cpu)",
    )
    parser.add_argument(
        "--golden_testfile",
        required=True,
        help="test set with golden syntactic and lexical signals",
    )
    parser.add_argument(
        "--max_fpr",
        default=0.01,
        type=float,
        help="false-positive-rate ceiling for the 'Detection Accuracy' (TPR) metric",
    )
    return parser.parse_args(argv)


def _collect_predictions(detector, data):
    """Run *detector* over every item of *data* and flatten scores and labels.

    Each item is read via the keys "text", "syntactic", "lexical" and "label".
    For every item, the detector's scores and the three label sequences are
    truncated to their common length, so the returned lists stay pairwise
    aligned.

    Returns:
        Tuple ``(logits, syntactic_label, lexical_label, labels)`` of four
        lists of equal length.
    """
    syntactic_label, lexical_label, labels, logits = [], [], [], []
    for item in tqdm(data):
        # NOTE(review): the meaning of the second positional argument is
        # defined by Detector.__call__ (not visible here); kept as-is.
        predictions = detector(item["text"], False)
        # The score is taken from index 1 of each prediction — presumably
        # (segment, score) pairs; confirm against Detector.
        out = [p[1] for p in predictions]
        # Truncate everything to a common length. Truncating only the labels to
        # len(out) breaks alignment when a label list is shorter than the
        # detector output, and the flat lists then differ in length.
        n = min(
            len(out),
            len(item["syntactic"]),
            len(item["lexical"]),
            len(item["label"]),
        )
        logits += out[:n]
        syntactic_label += item["syntactic"][:n]
        lexical_label += item["lexical"][:n]
        labels += item["label"][:n]
    return logits, syntactic_label, lexical_label, labels


def _compute_metrics(logits, syntactic_label, lexical_label, labels, max_fpr=0.01):
    """Compute correlation and detection metrics from aligned score/label lists.

    Args:
        logits: detector scores.
        syntactic_label: golden syntactic signal, aligned with *logits*.
        lexical_label: golden lexical signal, aligned with *logits*.
        labels: binary ground-truth labels (positive class is 1).
        max_fpr: FPR ceiling used for the "Detection Accuracy" entry.

    Returns:
        Dict with Pearson correlations, ROC AUC, and the true-positive rate at
        the last ROC point whose false-positive rate is below *max_fpr*.
    """
    syntactic_corr, _ = pearsonr(logits, syntactic_label)
    lexical_corr, _ = pearsonr(logits, lexical_label)
    fpr, tpr, _ = roc_curve(labels, logits, pos_label=1)
    # roc_curve's first point has fpr == 0, so this selection is non-empty for
    # any positive max_fpr; the fallback only covers a non-positive ceiling.
    tpr_below = tpr[fpr < max_fpr]
    return {
        "syntactic_corr": syntactic_corr,
        "lexical_corr": lexical_corr,
        "auc": roc_auc_score(labels, logits),
        # Despite the name, this is TPR at FPR < max_fpr (not accuracy).
        "Detection Accuracy": tpr_below[-1] if tpr_below.size else 0.0,
    }


def main(argv=None):
    """Evaluate a saved Detector on a pickled golden test set and print metrics.

    Args:
        argv: optional argument list (defaults to ``sys.argv[1:]``). This lets
            the entry point be driven from other scripts or tests.
    """
    args = _parse_args(argv)
    detector = Detector(args.model_path, args.device)
    # SECURITY: pickle.load can execute arbitrary code while deserializing.
    # Only load test files from a trusted source.
    with open(args.golden_testfile, "rb") as f:
        data = pkl.load(f)
    logits, syntactic_label, lexical_label, labels = _collect_predictions(
        detector, data
    )
    metrics = _compute_metrics(
        logits, syntactic_label, lexical_label, labels, args.max_fpr
    )
    for key, value in metrics.items():
        print(f"{key}: {value}")


if __name__ == "__main__":
    main()