-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathfilter_machine_generate.py
45 lines (35 loc) · 1.03 KB
/
filter_machine_generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from utils import compute_rouge, save_results
from datasets import load_dataset
import os
import concurrent.futures
from tqdm.auto import tqdm
# Load the seed examples from data/data_seeds.json; they form the initial
# reference pool that later candidates are compared against.
origin = load_dataset(
    'json',
    data_files=os.path.join('data', 'data_seeds.json'),
    split='train',
)

# One reference text per seed row: the persona string immediately followed by
# the situation string (no separator is inserted between them).
origin_datas = [row['persona'] + row['situation'] for row in origin]
# Load the machine-generated candidates from data/machine_generate.json.
machine_data = load_dataset(
    'json',
    data_files=os.path.join('data', 'machine_generate.json'),
    split='train',
)

# Candidates that pass the similarity filter, in their original order.
filter_datas = []

# A single executor serves the whole run. Creating one per candidate would
# spin up and join a fresh set of worker threads on every iteration even
# though the pool configuration never changes.
with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
    for data in tqdm(machine_data):
        # Same text layout as the reference pool: persona + situation.
        text = data['persona'] + data['situation']

        # Score this candidate against every text currently in the pool.
        futures = [
            executor.submit(compute_rouge, text, label)
            for label in origin_datas
        ]
        scores = [f.result() for f in futures]

        # Keep the candidate only if its highest score against the pool is at
        # most 0.7. `default=0.0` covers an empty pool (no seeds loaded),
        # where a bare max() would raise ValueError; with nothing to compare
        # against, the candidate is kept.
        if max(scores, default=0.0) <= 0.7:
            filter_datas.append(data)
            # Accepted texts join the pool so later candidates are also
            # compared against them.
            # NOTE(review): the original indentation was lost; this assumes
            # only accepted texts are added to the pool -- confirm.
            origin_datas.append(text)

# Persist the surviving candidates.
save_results(filter_datas, 'filter_data.json')