-
Notifications
You must be signed in to change notification settings - Fork 48
/
data_preprocess.py
65 lines (52 loc) · 1.98 KB
/
data_preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# -*- encoding: utf-8 -*-
"""SNLI data preprocessing."""
import functools
import time

import jsonlines
from tqdm import tqdm
def timer(func):
    """Decorator that prints how long each call to ``func`` takes.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper that forwards all positional and keyword arguments to
        ``func``, prints the elapsed time in seconds, and returns whatever
        ``func`` returned. If ``func`` raises, the exception propagates and
        nothing is printed.
    """
    # Preserve the wrapped function's __name__/__doc__ so introspection and
    # debugging still see the original function rather than `wrapper`.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution, so the measured
        # interval is not affected by system clock adjustments.
        ts = time.perf_counter()
        res = func(*args, **kwargs)
        te = time.perf_counter()
        print(f"function: `{func.__name__}` running time: {te - ts:.4f} secs")
        return res
    return wrapper
@timer
def snli_preprocess(src_path: str, dst_path: str) -> None:
    """Process the raw Chinese SNLI data.

    Reads JSON-lines records from ``src_path``, groups them by their
    ``sentence1`` value, and for each group remembers the most recently seen
    ``sentence2`` labelled ``entailment`` and the most recently seen one
    labelled ``contradiction``. Groups that have both are written to
    ``dst_path`` as JSON-lines objects with the keys ``origin``,
    ``entailment`` and ``contradiction``.

    Args:
        src_path (str): Path of the raw input file.
        dst_path (str): Path of the output file.
    """
    kept_labels = ('entailment', 'contradiction')

    # Organize the data: one bucket per distinct premise sentence.
    grouped = {}
    with jsonlines.open(src_path, 'r') as reader:
        for record in tqdm(reader):
            premise = record.get('sentence1')
            if not premise:
                # Records without a usable premise are skipped.
                continue
            # A bucket is created even when the label is not one we keep.
            bucket = grouped.setdefault(premise, {})
            label = record.get('gold_label')
            if label in kept_labels:
                # Later records overwrite earlier ones for the same label.
                bucket[label] = record.get('sentence2')

    # Filter: keep only premises that have both a non-empty entailment and a
    # non-empty contradiction sentence.
    triples = []
    for premise, bucket in grouped.items():
        positive = bucket.get('entailment')
        negative = bucket.get('contradiction')
        if positive and negative:
            triples.append(
                {'origin': premise, 'entailment': positive, 'contradiction': negative}
            )

    # Write the output file.
    with jsonlines.open(dst_path, 'w') as writer:
        writer.write_all(triples)
if __name__ == '__main__':
    # (source, destination) pairs, processed in order: train, test, dev.
    jobs = (
        ('datasets/cnsd-snli/cnsd_snli_v1.0.train.jsonl', 'datasets/cnsd-snli/train.txt'),
        ('datasets/cnsd-snli/cnsd_snli_v1.0.test.jsonl', 'datasets/cnsd-snli/test.txt'),
        ('datasets/cnsd-snli/cnsd_snli_v1.0.dev.jsonl', 'datasets/cnsd-snli/dev.txt'),
    )
    for source_file, target_file in jobs:
        snli_preprocess(source_file, target_file)