#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import csv
import json
import argparse
import logging
import logging.config
logging.config.fileConfig('./utils/logging.conf')
import numpy as np
from tqdm import tqdm
from glob import glob
from collections import Counter
from tensorflow.keras.models import load_model
from train_USElectionDebates import read_data_US
from pre_process_UNSC import *
from utils.arg_metav_formatter import *
from utils.model_utils import *


def read_data_UNSC(max_seq_length=512, directory="./data/UNSC/pred/"):
    """
    Function reads preprocessed UNSC corpus data
    Args:
        max_seq_length (int): maximum sequence length used in training
        directory (str): directory in which to find the preprocessed files
    Returns:
        pred_tokens (dict): mapping between unique UNSC speech IDs and
        tokenized input
        pred_X (np.ndarray): input ALBERT token IDs
        pred_mask (np.ndarray): input mask indicating which tokens are
        relevant to the outcome; this includes all corpus tokens and
        excludes all BERT special tokens
    """
    check = glob(os.path.join(directory, "*" + str(max_seq_length) + "*"))
    if len(check) < 3:
        raise FileNotFoundError("Preprocessed UNSC data not "
                                "found, please preprocess data first")
    else:
        with open(
                os.path.join(directory,
                             "pred_tokens_" + str(max_seq_length) + ".json"),
                "r") as f:
            pred_tokens = json.load(f)
        pred_X = np.load(
            os.path.join(directory, "pred_X_" + str(max_seq_length) + ".npy"))
        pred_mask = np.load(
            os.path.join(directory,
                         "pred_mask_" + str(max_seq_length) + ".npy"))
    return pred_tokens, pred_X, pred_mask
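
# Illustrative usage (assumes pre_process_UNSC.py has already written the
# pred_tokens_*.json, pred_X_*.npy and pred_mask_*.npy files to the default
# prediction directory):
#   pred_tokens, pred_X, pred_mask = read_data_UNSC(max_seq_length=512)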


def load_saved_model(model_path):
    """
    Function to load a saved keras model
    Args:
        model_path (str): path to *.h5 keras model
    Returns:
        model (tensorflow.python.keras.engine.training.Model): saved keras model
    """
    l_bert, model_ckpt = fetch_bert_layer()
    model = load_model(model_path,
                       custom_objects={
                           "BertModelLayer": l_bert,
                           "argument_candidate_acc": class_acc(3)
                       })
    return model
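
# Illustrative usage (the model path below is a placeholder): the custom BERT
# layer and the custom accuracy metric are not standard keras objects, so they
# have to be registered through custom_objects for load_model to work:
#   model = load_saved_model("./path/to/trained_model.h5")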


def pred_model_UNSC(model_path,
                    max_seq_length=512,
                    direct_save="./data/UNSC/pred/",
                    force_pred=False):
    """
    Function to predict with a given saved model on the UNSC corpus
    Args:
        model_path (str): path to *.h5 keras model
        max_seq_length (int): maximum sequence length used in training
        direct_save (str): directory where to save predictions
        force_pred (bool): whether to force prediction even when a cached
        prediction already exists
    Returns:
        y_pred (np.ndarray): model predictions on UNSC corpus
    """
    cache_file = os.path.join(direct_save,
                              "pred_Yhat_" + str(max_seq_length) + ".npy")
    if not force_pred and os.path.isfile(cache_file):
        y_pred = np.load(cache_file)
    else:
        _, pred_X, _ = read_data_UNSC(max_seq_length)
        model = load_saved_model(model_path)
        y_pred = model.predict(pred_X, batch_size=128)
        y_pred = np.argmax(y_pred, axis=-1)
        np.save(cache_file, y_pred)
    return y_pred
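
# Illustrative usage (placeholder model path): predictions are cached as a
# .npy file in direct_save, so a repeat call with force_pred=False reloads the
# cached array instead of re-running the model:
#   y_pred = pred_model_UNSC("./path/to/trained_model.h5", max_seq_length=512)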


def summary_info_UNSC_pred(collection,
                           max_seq_length=512,
                           directory="./data/UNSC/pred/"):
    """
    Function to write summary statistics on predicted token types to file
    Args:
        collection (dict): mapping from speech IDs to (token, label) pairs
        max_seq_length (int): maximum sequence length used in training
        directory (str): directory to output files
    """
    new_collection = []
    # get respective token-label counts per speech
    for i, el in enumerate(list(collection.keys())):
        new_collection.append([el])
        tmp = []
        for sub_el in collection[el]:
            tmp.append(sub_el[1])
        local_dict = dict(Counter(tmp))
        # default to zero when a label type does not occur in a speech
        N = local_dict.get("N", 0)
        C = local_dict.get("C", 0)
        P = local_dict.get("P", 0)
        new_collection[i].extend([N, C, P])
    # write to csv file
    with open(
            os.path.join(directory,
                         "pred_tokens_stats_" + str(max_seq_length) + ".csv"),
            "w") as f:
        writer = csv.writer(f)
        writer.writerow(["speech", "N", "C", "P"])
        writer.writerows(new_collection)
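
# The resulting csv contains one row per speech with the counts of the
# predicted token labels, along the lines of (entries below are placeholders,
# not real output):
#   speech,N,C,P
#   <speech_id>,<count of N>,<count of C>,<count of P>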


def simplify_results(y_pred,
                     max_seq_length=512,
                     directory="./data/UNSC/pred/"):
    """
    Function to simplify model predictions on the UNSC corpus into a
    human-readable format
    Args:
        y_pred (np.ndarray): model predictions on UNSC corpus
        max_seq_length (int): maximum sequence length used in training
        directory (str): directory where to save results
    Returns:
        clean_results (dict): simplified dictionary mapping speech IDs to
        model predictions
    """
    pred_tokens, _, pred_mask = read_data_UNSC(max_seq_length)
    _, _, _, _, label_map = read_data_US(max_seq_length)
    label_map_inverse = {item[1]: item[0] for item in label_map.items()}
    keys = list(pred_tokens.keys())
    clean_results = {}
    for i in tqdm(range(pred_mask.shape[0])):
        clean_results[keys[i]] = [
            (pred_tokens[keys[i]][j], label_map_inverse[y_pred[i, j]])
            for j, binary in enumerate(pred_mask[i].tolist()) if binary == 1
        ]
    with open(
            os.path.join(directory,
                         "pred_clean_" + str(max_seq_length) + ".json"),
            "w") as f:
        json.dump(clean_results, f, ensure_ascii=False)
    # execute pipeline to get summary info
    summary_info_UNSC_pred(clean_results, max_seq_length, directory)
    return clean_results
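
# Each entry of clean_results maps a speech ID to a list of
# (token, predicted_label) pairs, with labels taken from the US election
# debates label map; the summary step above then counts the "N", "C" and "P"
# label types per speech.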


if __name__ == "__main__":
    parser = argparse.ArgumentParser(formatter_class=arg_metav_formatter)
    parser.add_argument("--max-seq-length",
                        type=int,
                        default=512,
                        help="maximum sequence length of tokenized IDs")
    parser.add_argument("--force-pred",
                        action="store_true",
                        default=False,
                        help="option to force redoing prediction despite the"
                        " presence of an already produced binary")
    parser.add_argument("--verbosity",
                        type=int,
                        default=1,
                        help="0 for no text, 1 for verbose text")
    required = parser.add_argument_group("required named arguments")
    required.add_argument("--model",
                          required=True,
                          type=str,
                          help="path to model *.h5 file")
    args = parser.parse_args()
    if args.verbosity == 1:
        logger = logging.getLogger('base')
    else:
        logger = logging.getLogger('root')
    logger.info("Loading model predictions, might take some time...")
    y_pred = pred_model_UNSC(model_path=args.model,
                             max_seq_length=args.max_seq_length,
                             force_pred=args.force_pred)
    logger.info("Simplifying model predictions for human readability")
    clean_results = simplify_results(y_pred, args.max_seq_length)
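
# Example invocations (the model path below is a placeholder for a trained
# *.h5 keras model):
#   python3 predict_UNSC.py --model ./path/to/trained_model.h5
#   python3 predict_UNSC.py --model ./path/to/trained_model.h5 --force-pred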