# -*- coding: utf-8 -*-
import os
import csv
import numpy as np
import pickle as pkl
import tensorflow as tf
from tensorflow.contrib import learn
import data_helper
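# NOTE: this script targets TensorFlow 1.x; tf.contrib and the tf.flags / tf.app
# modules were removed in TensorFlow 2.x.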
# Show errors only (TF_CPP_MIN_LOG_LEVEL=2 filters out INFO and WARNING logs)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# File paths
tf.flags.DEFINE_string('test_data_file', None, 'Test data file path')
tf.flags.DEFINE_string('run_dir', None, 'Restore the model from this run')
tf.flags.DEFINE_string('checkpoint', None, 'Restore the graph from this checkpoint')
# Test batch size
tf.flags.DEFINE_integer('batch_size', 64, 'Test batch size')
FLAGS = tf.flags.FLAGS
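# Example invocation (paths and checkpoint name are illustrative):
#   python test.py --test_data_file=data/test.csv --run_dir=runs/1512462690 --checkpoint=clf-10000
# Optional guard: none of the path flags has a default, so fail fast if one is missing
for _name in ('test_data_file', 'run_dir', 'checkpoint'):
    if getattr(FLAGS, _name) is None:
        raise ValueError('--{} is required'.format(_name))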
# Restore parameters
with open(os.path.join(FLAGS.run_dir, 'params.pkl'), 'rb') as f:
    params = pkl.load(f, encoding='bytes')
# Restore vocabulary processor
vocab_processor = learn.preprocessing.VocabularyProcessor.restore(os.path.join(FLAGS.run_dir, 'vocab'))
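# The restored processor replays the training-time token-to-id mapping, so test
# examples are encoded consistently with the embedding table learned in training.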
# Load test data
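# shuffle=False keeps examples in file order, so predictions stay aligned with
# `labels` when the CSV is written at the end of this script.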
data, labels, lengths, _ = data_helper.load_data(file_path=FLAGS.test_data_file,
                                                 sw_path=params['stop_word_file'],
                                                 min_frequency=params['min_frequency'],
                                                 max_length=params['max_length'],
                                                 language=params['language'],
                                                 vocab_processor=vocab_processor,
                                                 shuffle=False)
# Restore graph
graph = tf.Graph()
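# The tensor names used below ('input_x:0', 'softmax/predictions:0', ...) must match
# the names given to the corresponding ops when the graph was built during training.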
with graph.as_default():
    sess = tf.Session()

    # Restore metagraph
    saver = tf.train.import_meta_graph('{}.meta'.format(os.path.join(FLAGS.run_dir, 'model', FLAGS.checkpoint)))
    # Restore weights
    saver.restore(sess, os.path.join(FLAGS.run_dir, 'model', FLAGS.checkpoint))

    # Get tensors
    input_x = graph.get_tensor_by_name('input_x:0')
    input_y = graph.get_tensor_by_name('input_y:0')
    keep_prob = graph.get_tensor_by_name('keep_prob:0')
    predictions = graph.get_tensor_by_name('softmax/predictions:0')
    accuracy = graph.get_tensor_by_name('accuracy/accuracy:0')

    # Generate batches (one epoch over the test set)
    batches = data_helper.batch_iter(data, labels, lengths, FLAGS.batch_size, 1)

    all_predictions = []
    sum_accuracy = 0

    # Test
    for batch in batches:
        x_test, y_test, x_lengths = batch
        if params['clf'] == 'cnn':
            feed_dict = {input_x: x_test, input_y: y_test, keep_prob: 1.0}
        else:
            batch_size = graph.get_tensor_by_name('batch_size:0')
            sequence_length = graph.get_tensor_by_name('sequence_length:0')
            # Feed the real size of this batch: the last batch can be smaller than FLAGS.batch_size
            feed_dict = {input_x: x_test, input_y: y_test, batch_size: len(x_test),
                         sequence_length: x_lengths, keep_prob: 1.0}
        batch_predictions, batch_accuracy = sess.run([predictions, accuracy], feed_dict)

        # Weight each batch's accuracy by its size so a short final batch does not skew the mean
        sum_accuracy += batch_accuracy * len(x_test)
        all_predictions = np.concatenate([all_predictions, batch_predictions])

    final_accuracy = sum_accuracy / len(data)
# Print test accuracy
print('Test accuracy: {}'.format(final_accuracy))
# Save all predictions
with open(os.path.join(FLAGS.run_dir, 'predictions.csv'), 'w', encoding='utf-8', newline='') as f:
    csvwriter = csv.writer(f)
    csvwriter.writerow(['True class', 'Prediction'])
    for i in range(len(all_predictions)):
        # np.concatenate promotes predictions to float; cast back to int for a clean CSV
        csvwriter.writerow([labels[i], int(all_predictions[i])])
print('Predictions saved to {}'.format(os.path.join(FLAGS.run_dir, 'predictions.csv')))
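# Resulting predictions.csv layout (class values illustrative):
#   True class,Prediction
#   0,0
#   1,1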