-
Notifications
You must be signed in to change notification settings - Fork 0
/
parse.py
111 lines (94 loc) · 4.25 KB
/
parse.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import re
import time
import argparse
from tqdm import tqdm
from train import Model
from data_utils import Sentence, log_stats
from feature_generator import FeatureGenerator
from data_parser import DataParser
"""_summary_
Given a trained model file (and possibly vocabulary file) reads in CoNLL data and writes
CoNLL data where fields 7 and 8 contain dependency tree info.
parse.py should take in a sentence and from the parser configuration
generate labels, which will determine each subsequent parser configuration (from which features can
be determined). Both files, then, should point to a third library file which contains code that, given
configuration at time t ct and label l, determines ct+1.
"""
LABEL_PATTERN = r'([a-z_]+)(\(([a-z:]+)\))*'


def decompose_pred(pred_label):
    """Split a predicted label into its transition type and dependency.

    A label looks like ``'left_arc(nsubj)'`` or plain ``'shift'``. Returns a
    ``(transition_type, dependency)`` pair where ``dependency`` is ``None``
    when the label carries no parenthesised dependency.
    """
    match = re.search(LABEL_PATTERN, pred_label)
    # group 1 is the transition name; group 3 the inner dependency (if any).
    trans_type = match.group(1)
    dep = match.group(3)
    return trans_type, dep
def infer_sentence_tree(
    model: Model, s: Sentence, trange: tqdm, verbose: int,
    drop_blocking_elements: bool
) -> Sentence:
    """Greedily infer the dependency tree of one sentence, in place.

    Repeatedly extracts features from the current parser configuration,
    asks the model for the next transition label, and applies it via
    ``s.update_state`` until the configuration is exhausted.

    Args:
        model: trained classifier; ``classify`` is called on one feature row.
        s: sentence/parser state, mutated in place.
        trange: tqdm progress bar, used only for postfix display when
            ``verbose >= 2``.
        verbose: verbosity level.
        drop_blocking_elements: when True and the predicted transition cannot
            be applied, force progress by popping the blocking stack element
            (or shifting from the buffer) instead of giving up.

    Returns:
        The same ``Sentence`` object, updated in place.
    """
    num_infers = 0
    while True:
        # getting the features of a current parse state
        s_feats = FeatureGenerator.extract_features(s)
        # reshape into a single-row batch — presumably the model expects a
        # (1, n_features) array; TODO confirm against Model.classify
        s_feats = s_feats.reshape((1, len(s_feats)))
        pred_label = model.classify(s_feats)
        trans_type, dep = decompose_pred(pred_label)
        # update_state returns falsy when the transition is not applicable
        # in the current configuration (inferred from the branch below).
        updated = s.update_state(curr_trans=trans_type, predicted_dep=dep)
        if verbose >= 2:
            trange.set_postfix({
                'trans_count': f'{num_infers}/{2*len(s) + 1} [(2*tokens) + 1]',
                'prev_trans': pred_label,
                'stack': f'{[t.word for t in s.stack]}',
            })
        num_infers += 1
        # Tweak for a better chance of catching correct dependencies:
        # drop blocking elements and continue classification
        if not updated:
            if drop_blocking_elements:
                if len(s.stack) > 1:
                    dropped_token = s.stack.pop(-1)
                    # assign something arbitrarily (head is the previous token)
                    dropped_token.head = str(int(dropped_token.token_id) - 1)
                elif len(s.buffer) > 0:
                    # stack cannot be reduced: shift the next buffer token
                    s.stack.append(s.buffer.pop(0))
            else:
                # transition not applicable and dropping disabled: give up
                break
        if s.is_exausted():
            break
    return s
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Parser file configuration.')
parser.add_argument('-m', default='train.model', type=str, help='model file including vocab (encoding)')
parser.add_argument('-i', default='parse.in', type=str, help='input filepath')
parser.add_argument('-o', default='parse.out', type=str, help='output filepath')
parser.add_argument('-trans', default='std', type=str, help='transition system')
parser.add_argument('-verbose', default=0, type=int, help='verbose')
parser.add_argument('-dropb', default=True, type=bool, help='whether to drop blocking elements while transiting')
args = parser.parse_args()
if args.trans == 'std':
print('Using arc-standard transition system')
else:
print('Using arc-eager transition system')
if args.dropb:
print('Dropping blocking elements')
sentences = DataParser.read_parse_tree(args.i, transition_system=args.trans)
log_stats(sentences)
# Ensure unlabeled
def unlabel_sentence(s):
for t in s.tokens:
t.head = 0
t.dep = 0
return s
# Ensure empty head and dep
sentences = list(map(lambda x: unlabel_sentence(x), sentences))
model = Model.load_model(args.m)
sentences_trange = tqdm(sentences, desc='Trees/Sentences')
if args.verbose <= 0:
print('\nStopped verbosing completely!')
sentences_trange.close()
infer_stime = time.time()
for s in sentences_trange:
infer_sentence_tree(model, s, sentences_trange,
args.verbose,
drop_blocking_elements=args.dropb)
print(f'Finshed infering sentences trees in {time.time() - infer_stime: < .2f}s')
# write CoNLL formatted file with depend tree info aka. field 7 & 8
DataParser.update_conll_file(sentences, args.i, args.o)
print(f'Finished writing updated CoNLL file as {args.o}')