# forked from Tiiiger/bert_score
# tune_layers.py — 116 lines (95 loc) · 3.95 KB
import os
import re
import argparse
import torch
import numpy as np
from tqdm.auto import tqdm, trange
from collections import defaultdict
from scipy.stats import pearsonr
import bert_score
def get_wmt16(lang_pair, data_folder="wmt16"):
    """Load WMT16 segment-level Direct Assessment data for one language pair.

    Args:
        lang_pair: language-pair code such as "de-en"; selects the file suffix.
        data_folder: root folder containing the wmt16-metrics-results tree.

    Returns:
        (gold_scores, all_refs, all_hyps): human DA scores as floats, plus the
        reference and MT-system segments as lists of strings, line-aligned.
    """
    seg_dir = os.path.join(
        data_folder, "wmt16-metrics-results/seg-level-results/DAseg-newstest2016"
    )

    def _read_lines(kind):
        # All three files share the "DAseg-newstest2016.<kind>.<lang_pair>" naming.
        path = os.path.join(seg_dir, f"DAseg-newstest2016.{kind}.{lang_pair}")
        with open(path) as f:
            return f.read().strip().split("\n")

    gold_scores = list(map(float, _read_lines("human")))
    all_refs = _read_lines("reference")
    all_hyps = _read_lines("mt-system")
    return gold_scores, all_refs, all_hyps
def get_wmt16_seg_to_bert_score(lang_pair, scorer, data_folder="wmt16", batch_size=64):
    """Score WMT16 segments for one language pair with the given BERTScorer.

    Returns (scores, gold_scores, max_length): the list produced by
    scorer.score (per-layer results when the scorer uses all_layers=True),
    the human DA judgments, and the tokenizer's single-sentence length limit.
    """
    gold_scores, references, candidates = get_wmt16(lang_pair, data_folder=data_folder)
    # IDF weights must be computed before scoring when the scorer was built with idf=True.
    if scorer.idf:
        scorer.compute_idf(references)
    raw = scorer.score(candidates, references, verbose=False, batch_size=batch_size)
    return list(raw), gold_scores, scorer._tokenizer.max_len_single_sentence
def main():
    """Tune the best-correlating layer for each model on WMT16 DA data.

    For every model passed via --model, scores all WMT16 segments with
    all_layers=True, computes the Pearson correlation between each layer's
    F score and the human judgments per language pair, and records the
    layer with the highest mean correlation to a text log and a CSV.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--data", default="wmt16", help="path to wmt16 data")
    parser.add_argument("-m", "--model", nargs="+", help="models to tune")
    parser.add_argument("-l", "--log_file", default="best_layers_log.txt", help="log file path")
    parser.add_argument("--idf", action="store_true")
    parser.add_argument("-b", "--batch_size", type=int, default=64)
    parser.add_argument(
        "--lang_pairs",
        nargs="+",
        default=["cs-en", "de-en", "fi-en", "ro-en", "ru-en", "tr-en"],
        help="language pairs used for tuning",
    )
    args = parser.parse_args()

    if args.log_file.endswith(".txt"):
        csv_file = args.log_file.replace(".txt", ".csv")
    else:
        csv_file = args.log_file + ".csv"

    # Inference only — no gradients needed anywhere.
    torch.set_grad_enabled(False)

    for network in args.model:
        # BUG FIX: --idf was previously ignored (idf=False was hard-coded), so
        # the "(idf)" entries in the log never actually used IDF weighting.
        # num_layers=100 is a deliberate over-estimate; with all_layers=True the
        # scorer returns every real layer and the loop below stops at the first
        # layer index that is missing from the results.
        scorer = bert_score.scorer.BERTScorer(
            model_type=network, num_layers=100, idf=args.idf, all_layers=True
        )

        results = defaultdict(dict)
        for lang_pair in tqdm(args.lang_pairs):
            # BUG FIX: pass data_folder=args.data — the -d option was parsed
            # but never used, so a custom data path was silently ignored.
            scores, gold_scores, max_length = get_wmt16_seg_to_bert_score(
                lang_pair, scorer, data_folder=args.data, batch_size=args.batch_size
            )
            # scores[2] holds the F measure, one score vector per layer.
            for i, score in enumerate(scores[2]):
                results[f"{lang_pair} {i}"][f"{network} F"] = pearsonr(score, gold_scores)[0]

        # Start below any achievable correlation so that all-negative
        # correlations still select a genuine best layer.
        best_layer, best_corr = 0, float("-inf")
        for num_layer in range(100):
            if f"{args.lang_pairs[0]} {num_layer}" not in results:
                break  # past the model's last real layer
            per_pair = [results[f"{lp} {num_layer}"][f"{network} F"] for lp in args.lang_pairs]
            corr = np.mean(per_pair)
            results[f"avg {num_layer}"][f"{network} F"] = corr
            print(network, num_layer, corr)
            if corr > best_corr:
                best_layer, best_corr = num_layer, corr

        if args.idf:
            msg = f"'{network}' (idf): {best_layer}, # {best_corr}"
        else:
            msg = f"'{network}': {best_layer}, # {best_corr}"
        print(msg)
        with open(args.log_file, "a") as f:
            print(msg, file=f)
        with open(csv_file, "a") as f:
            print(f"{network},{best_layer},{best_corr},,{max_length}", file=f)
        # Free model memory before loading the next network.
        del scorer


if __name__ == "__main__":
    main()