-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathpredict.py
105 lines (88 loc) · 4.19 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import tensorflow as tf
import numpy as np
# NOTE(review): 'pickle' appears unused in this file — candidate for removal.
import pickle
# Used for reliably getting the current hostname.
import socket
import time
# Command-line flags controlling how predictions are computed and where
# the inputs/outputs live.
tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement")
tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")
# NOTE(review): 'batch_size' is defined but never read below — the prediction
# loop feeds rows one at a time.
tf.flags.DEFINE_integer("batch_size", 1, "Batch Size (default: 1)")
# Example: './data/runs/euler/local-w2v-275d-1466050948/checkpoints/model-96690'
tf.flags.DEFINE_string("checkpoint_file", None, "Checkpoint file from the training run.")
tf.flags.DEFINE_string(
    "validation_data_fname",
    "./data/preprocessing/validateX.npy",
    "The numpy dump of the validation data for Kaggle. Should ideally be"
    " preprocessed the same way as the training data.")
tf.flags.DEFINE_string(
    "input_x_name",
    "input_x",
    "The graph node name of the input data. Hint: if you forget to name it,"
    " it's probably called 'Placeholder'.")
tf.flags.DEFINE_string(
    "predictions_name",
    "output/predictions",
    "The graph node name of the prediction computation. Hint: if you forget to"
    " name it, it's probably called 'Softmax' or 'output/Softmax'.")
FLAGS = tf.flags.FLAGS
# NOTE(review): '_parse_flags' is a private TF 1.x API; newer TF versions parse
# flags lazily on first attribute access — confirm against the pinned TF version.
FLAGS._parse_flags()
# Fail fast when no checkpoint was supplied on the command line: without one
# there is nothing to restore and nothing to predict with.
if FLAGS.checkpoint_file is None:
    raise ValueError("Please specify a TensorFlow checkpoint file to use for"
                     " making the predictions (--checkpoint_file <file>).")

checkpoint_file = FLAGS.checkpoint_file
validation_data_fname = FLAGS.validation_data_fname
print("Validation data file: {0}".format(validation_data_fname))

# Load the preprocessed validation set (numpy dump produced upstream).
validation_data = np.load(validation_data_fname)

# Timestamp the output paths so successive runs never clobber each other.
# The '.meta' sidecar records provenance for the CSV next to it.
timestamp = int(time.time())
filename = "./data/output/prediction_cnn_{0}.csv".format(timestamp)
meta_filename = "{0}.meta".format(filename)

print("Predicting using checkpoint file [{0}].".format(checkpoint_file))
print("Will write predictions to file [{0}].".format(filename))
print("Validation data shape: {0}".format(validation_data.shape))
# Build a fresh graph, restore the trained model from the checkpoint, run the
# restored prediction op over every validation row, then write the Kaggle
# submission CSV plus a small provenance sidecar file.
graph = tf.Graph()
with graph.as_default():
    # TODO(andrei): Is this config (and its associated flags) really necessary?
    session_conf = tf.ConfigProto(
        allow_soft_placement=FLAGS.allow_soft_placement,
        log_device_placement=FLAGS.log_device_placement)
    sess = tf.Session(config=session_conf)
    with sess.as_default():
        print("Loading saved meta graph...")
        saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
        print("Restoring variables...")
        saver.restore(sess, checkpoint_file)
        print("Finished TF graph load.")

        # Look up the input placeholder and the prediction op by name.
        input_x = graph.get_operation_by_name(FLAGS.input_x_name).outputs[0]
        dropout_keep_prob = graph.get_operation_by_name(
            "dropout_keep_prob").outputs[0]
        # Tensors we want to evaluate
        predictions = graph.get_operation_by_name(
            FLAGS.predictions_name).outputs[0]

        # Collect (1-based id, prediction vector) pairs here.
        all_predictions = []
        print("Computing predictions...")
        # NOTE(review): rows are fed one at a time, so FLAGS.batch_size is
        # effectively unused; batching here would speed this up considerably.
        for row_id, row in enumerate(validation_data):
            if (row_id + 1) % 1000 == 0:
                print("Done tweets: {0}/{1}".format(row_id + 1,
                                                    len(validation_data)))
            prediction = sess.run(predictions, {
                input_x: [row],
                # Disable dropout at inference time.
                dropout_keep_prob: 1.0
            })[0]
            all_predictions.append((row_id + 1, prediction))
        print("Prediction done.")

print("Writing predictions to file...")
# Use a context manager so the CSV is flushed and closed even on error — the
# original left 'submission' open for the process lifetime, risking a
# truncated file.
with open(filename, 'w+') as submission:
    print('Id,Prediction', file=submission)
    # Ensure that IDs are from 1 to 10000, NOT from 0. Otherwise Kaggle
    # rejects the submission.
    for row_id, pred in all_predictions:
        # pred[0] >= 0.5 maps to label -1 — assumes index 0 is the
        # negative-class score; confirm against the training label order.
        if pred[0] >= 0.5:
            print("%d,-1" % row_id, file=submission)
        else:
            print("%d,1" % row_id, file=submission)
with open(meta_filename, 'w') as mf:
    print("Generated from checkpoint: {0}".format(checkpoint_file), file=mf)
    print("Hostname: {0}".format(socket.gethostname()), file=mf)
print("...done.")
print("Wrote predictions to: {0}".format(filename))
print("Wrote some simple metadata about how the predictions were"
      " generated to: {0}".format(meta_filename))