-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
94 lines (88 loc) · 3.2 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import json
from tqdm import tqdm
import os
import numpy as np
from random import choice
from itertools import groupby
import pickle
# Module-level settings shared by the functions below.
mode: int = 0  # not referenced by any code visible in this file
min_count: int = 2  # characters seen fewer times are dropped from the vocabulary
char_size: int = 128  # not referenced here; presumably an embedding width -- TODO confirm
def process_kb_data(path='./data/ccks2019_el/kb_data'):
    """Load the CCKS 2019 entity-linking knowledge base.

    Each line of ``path`` is a JSON object with ``subject_id``, ``subject``,
    an optional ``alias`` list and a ``data`` list of predicate/object pairs.

    Args:
        path: Path to the JSON-lines KB file. Defaults to the original
            hard-coded location, so existing callers are unaffected.

    Returns:
        ``(id2kb, kb2id)``:
        ``id2kb`` maps a subject id to
        ``{'subject_alias': [...], 'subject_desc': str}`` (both lowercased).
        ``kb2id`` maps each lowercased alias to the list of subject ids that
        carry it.
        Entities whose description is empty are skipped entirely.
    """
    id2kb = {}
    with open(path, encoding='utf-8') as f:
        for line in tqdm(f):
            record = json.loads(line)
            subject_id = record['subject_id']
            # Lowercase *before* de-duplicating. De-duplicating the raw names
            # first kept e.g. 'NBA' and 'nba' apart, so the same id was later
            # appended to kb2id['nba'] twice. dict.fromkeys also gives a
            # deterministic order (subject first, then aliases as listed).
            names = [record['subject']] + (record.get('alias') or [])
            subject_alias = list(dict.fromkeys(name.lower() for name in names))
            subject_desc = '\n'.join(
                '%s:%s' % (item['predicate'], item['object'])
                for item in record['data']
            ).lower()
            if subject_desc:
                id2kb[subject_id] = {'subject_alias': subject_alias,
                                     'subject_desc': subject_desc}

    # Inverted index: alias -> every subject id that lists it.
    kb2id = {}
    for subject_id, entry in id2kb.items():
        for alias in entry['subject_alias']:
            kb2id.setdefault(alias, []).append(subject_id)
    return id2kb, kb2id
def read_train(path=''):
    """Build entity-linking training samples from the CCKS 2019 train set.

    For every linked (non-``NIL``) mention, each KB entity that shares the
    mention's alias becomes one sample. It is labelled ``'1'`` if it is the
    gold entity and ``'0'`` otherwise. When the alias is unknown to the KB,
    only the gold entity is emitted, as a positive sample, provided it
    exists in the KB.

    Args:
        path: Path to the JSON-lines training file. An empty string (the
            default) falls back to ``./data/ccks2019_el/train.json``.

    Returns:
        ``(all_alies, id2kb, kb2id, chars)``:
        Each ``all_alies`` item is
        ``[text, mention, y, label, subject_desc, s1, s2]``, where ``y``
        marks the mention's characters with 1.
        ``s1``/``s2`` mark the start/end characters of every linked mention
        in ``text``; the same two lists are shared by all samples of that
        text.
        ``chars`` counts the characters of the training texts.
    """
    path = path or './data/ccks2019_el/train.json'
    id2kb, kb2id = process_kb_data()
    chars = {}
    all_alies = []
    with open(path, encoding='utf-8') as f:
        for line in tqdm(f):
            sample = json.loads(line)
            text = sample['text']
            for c in text:
                chars[c] = chars.get(c, 0) + 1
            s1 = [0] * len(text)
            s2 = [0] * len(text)
            temp = []
            for mention in sample['mention_data']:
                kb_id = mention['kb_id']
                if kb_id == 'NIL':
                    continue
                try:
                    name = mention['mention']
                    begin = int(mention['offset'])
                except (KeyError, TypeError, ValueError):
                    # Unusable record. The old bare ``except`` emitted a sample
                    # built from a stale or unbound ``y`` here; skip it instead.
                    continue
                end = begin + len(name) - 1
                if not name or begin < 0 or end >= len(text):
                    # Out-of-range span. Negative indices would silently mark
                    # the wrong characters, so drop the mention.
                    continue
                s1[begin] = 1
                s2[end] = 1
                y = [0] * len(text)
                y[begin:end + 1] = [1] * len(name)
                # KB aliases are lowercased in process_kb_data, so the lookup
                # key must be lowercased too or capitalised mentions never match.
                candidates = kb2id.get(name.lower())
                if candidates:
                    for cand in candidates:
                        label = '1' if cand == kb_id else '0'
                        temp.append([text, name, y, label,
                                     id2kb[cand]['subject_desc']])
                elif kb_id in id2kb:
                    # Alias unknown to the KB: keep only the gold entity.
                    temp.append([text, name, y, '1',
                                 id2kb[kb_id]['subject_desc']])
            all_alies.extend(item + [s1, s2] for item in temp)
    return all_alies, id2kb, kb2id, chars
def save_file_to_file():
    """Build the character vocabulary and pickle the KB and training data.

    Characters are counted over the training texts (from ``read_train``)
    plus all KB descriptions. Those seen fewer than ``min_count`` times are
    dropped. The surviving characters get ids starting at 2; ids 0 and 1
    are left unassigned (presumably padding / unknown -- TODO confirm
    against the model code).

    Writes ``id2kb``, ``kb2id``, ``id2char`` and ``char2id`` (in that order,
    as four consecutive pickles) to ``./data/kb.pkl``, and the training
    samples to ``./data/train.pkl``. Readers must ``pickle.load`` the KB
    file four times in the same order, and should only load these files
    from a trusted source.
    """
    all_alies, id2kb, kb2id, chars = read_train()
    # Pass the view itself, not iter(...), so tqdm can read its length and
    # show a total/ETA.
    for entry in tqdm(id2kb.values()):
        for c in entry['subject_desc']:
            chars[c] = chars.get(c, 0) + 1
    chars = {c: n for c, n in chars.items() if n >= min_count}
    id2char = {i + 2: c for i, c in enumerate(chars)}
    char2id = {c: i for i, c in id2char.items()}
    with open('./data/kb.pkl', 'wb') as fw:
        for obj in (id2kb, kb2id, id2char, char2id):
            pickle.dump(obj, fw)
    with open('./data/train.pkl', 'wb') as fw:
        pickle.dump(all_alies, fw)
    print('finish!')
# Uncomment the call below to rebuild ./data/kb.pkl and ./data/train.pkl.
# save_file_to_file()