data_input.py (forked from InsaneLife/dssm)
#!/usr/bin/env python
# encoding=utf-8
import json

import numpy as np

from config import Config

# global configuration (vocabulary, sequence length, special tokens, ...)
conf = Config()
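
# NOTE: config.Config is not shown in this file; judging from how `conf` is used
# below, it is assumed to expose at least:
#   conf.vocab_map    dict mapping a character to its integer id
#   conf.unk          key of the unknown token in vocab_map
#   conf.pad          key of the padding token in vocab_map
#   conf.max_seq_len  maximum sequence length for padding / truncation
#   conf.nwords       vocabulary size, used for bag-of-words vectors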


def gen_word_set(file_path, out_path='./data/words.txt'):
    """Collect the character vocabulary of a dataset file and write it to out_path."""
    word_set = set()
    with open(file_path, encoding='utf8') as f:
        for line in f.readlines():
            spline = line.strip().split('\t')
            # each valid line has exactly 5 tab-separated fields
            if len(spline) != 5:
                continue
            prefix, query_pred, title, tag, label = spline
            if label == '0':
                continue
            query_pred = json.loads(query_pred)
            # collect characters from the prefix, the title and every predicted query
            for w in prefix:
                word_set.add(w)
            for w in title:
                word_set.add(w)
            for each in query_pred:
                for w in each:
                    word_set.add(w)
    with open(out_path, 'w', encoding='utf8') as o:
        for w in word_set:
            o.write(w + '\n')
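
# A raw dataset line is expected to contain 5 tab-separated fields:
#   prefix \t query_prediction \t title \t tag \t label
# where query_prediction is JSON (see __main__ below); iterating over the decoded value
# yields candidate query strings, which is all this module relies on.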


def convert_word2id(query, vocab_map):
    """Map each character of query to its id, then pad / truncate to conf.max_seq_len."""
    ids = []
    for w in query:
        if w in vocab_map:
            ids.append(vocab_map[w])
        else:
            ids.append(vocab_map[conf.unk])
    # pad with the pad token, then cut to the maximum sequence length
    while len(ids) < conf.max_seq_len:
        ids.append(vocab_map[conf.pad])
    return ids[:conf.max_seq_len]
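
# For instance, with a hypothetical vocab_map {'[PAD]': 0, '[UNK]': 1, '天': 2, '气': 3},
# conf.pad = '[PAD]', conf.unk = '[UNK]' and conf.max_seq_len = 4 (all made up for
# illustration), convert_word2id('天气好', vocab_map) returns [2, 3, 1, 0]: known chars map
# to their ids, '好' falls back to the unk id, and the sequence is padded to length 4.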


def convert_seq2bow(query, vocab_map):
    """Convert query into a bag-of-words count vector of length conf.nwords."""
    bow_ids = np.zeros(conf.nwords)
    for w in query:
        if w in vocab_map:
            bow_ids[vocab_map[w]] += 1
        else:
            bow_ids[vocab_map[conf.unk]] += 1
    return bow_ids
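
# Using the same hypothetical vocab as above, convert_seq2bow('天天好', vocab_map) would give
# a length-conf.nwords vector with 2 at index 2 ('天' appears twice) and 1 at the unk index
# ('好' is out of vocabulary).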


def get_data(file_path):
    """
    Generate a DSSM-style dataset, converting words into word ids.
    :param file_path: path to the tab-separated dataset file
    :return: dict with keys query / query_len / doc_pos / doc_pos_len / doc_neg / doc_neg_len;
             each query comes with 1 positive title and 4 negative samples, so
             len(doc_neg) == 4 * len(query).
    """
    data_map = {'query': [], 'query_len': [], 'doc_pos': [], 'doc_pos_len': [], 'doc_neg': [], 'doc_neg_len': []}
    with open(file_path, encoding='utf8') as f:
        for line in f.readlines():
            spline = line.strip().split('\t')
            if len(spline) != 5:
                continue
            prefix, query_pred, title, tag, label = spline
            if label == '0':
                continue
            cur_arr, cur_len = [], []
            query_pred = json.loads(query_pred)
            # use the predicted queries (except the title itself) as negative samples
            for each in query_pred:
                if each == title:
                    continue
                cur_arr.append(convert_word2id(each, conf.vocab_map))
                each_len = len(each) if len(each) < conf.max_seq_len else conf.max_seq_len
                cur_len.append(each_len)
            # keep only examples with at least 4 negatives, and use exactly the first 4
            if len(cur_arr) >= 4:
                data_map['query'].append(convert_word2id(prefix, conf.vocab_map))
                data_map['query_len'].append(len(prefix) if len(prefix) < conf.max_seq_len else conf.max_seq_len)
                data_map['doc_pos'].append(convert_word2id(title, conf.vocab_map))
                data_map['doc_pos_len'].append(len(title) if len(title) < conf.max_seq_len else conf.max_seq_len)
                data_map['doc_neg'].extend(cur_arr[:4])
                data_map['doc_neg_len'].extend(cur_len[:4])
    return data_map
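

def batch_iter(data_map, batch_size=64):
    """Yield (query, doc_pos, doc_neg) mini-batches from a get_data() result.

    This helper is an illustrative addition, not part of the original repo; the
    batch size and the (query, doc_pos, doc_neg) grouping are assumptions based
    on how get_data() lays out its lists (4 consecutive negatives per query).
    """
    n = len(data_map['query'])
    for start in range(0, n, batch_size):
        end = min(start + batch_size, n)
        yield (data_map['query'][start:end],
               data_map['doc_pos'][start:end],
               # doc_neg stores 4 consecutive negatives per query
               data_map['doc_neg'][4 * start:4 * end])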


def get_data_siamese_rnn(file_path):
    """
    Generate a dataset for a siamese RNN, converting words into word ids.
    :param file_path: path to the tab-separated dataset file
    :return: [[prefix ids, title ids, label]], shape = [n, 3]
    """
    data_arr = []
    with open(file_path, encoding='utf8') as f:
        for line in f.readlines():
            spline = line.strip().split('\t')
            if len(spline) != 5:
                continue
            prefix, _, title, tag, label = spline
            prefix_seq = convert_word2id(prefix, conf.vocab_map)
            title_seq = convert_word2id(title, conf.vocab_map)
            data_arr.append([prefix_seq, title_seq, int(label)])
    return data_arr


def get_data_bow(file_path):
    """
    Generate a bag-of-words dataset, converting words into count vectors.
    :param file_path: path to the tab-separated dataset file
    :return: [[prefix bow, title bow, label]], shape = [n, 3]
    """
    data_arr = []
    with open(file_path, encoding='utf8') as f:
        for line in f.readlines():
            spline = line.strip().split('\t')
            if len(spline) != 5:
                continue
            prefix, _, title, tag, label = spline
            prefix_ids = convert_seq2bow(prefix, conf.vocab_map)
            title_ids = convert_seq2bow(title, conf.vocab_map)
            data_arr.append([prefix_ids, title_ids, int(label)])
    return data_arr


if __name__ == '__main__':
    # each line: prefix, query_prediction, title, tag, label
    # query_prediction is a JSON-encoded string.
    file_train = './data/oppo_round1_train_20180929.txt'
    file_vali = './data/oppo_round1_vali_20180929.txt'
    # data_train = get_data(file_train)
    data_train = get_data(file_vali)
    print(len(data_train['query']), len(data_train['doc_pos']), len(data_train['doc_neg']))
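
    # To (re)build the character vocabulary file consumed by config.Config, a one-off call
    # like the commented line below could be used; the default out_path ('./data/words.txt')
    # is an assumption from gen_word_set's signature.
    # gen_word_set(file_train)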