-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathtrain.py
128 lines (102 loc) · 4.56 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from __future__ import division
from __future__ import print_function
import time
import tensorflow as tf
from utils import *
from models import GCN, MLP, TAGCN
# Fix the seed of both TensorFlow's and NumPy's random number generators.
# NOTE(review): `np` is not imported directly in this file; presumably it is
# brought into scope by the star import from `utils` -- confirm.
seed = 123
tf.set_random_seed(seed)
np.random.seed(seed)
# Settings: command-line flags registered through TensorFlow's flags module.
flags = tf.app.flags
FLAGS = flags.FLAGS

# One row per flag: (definer, flag name, default value, help text).
# Rows are registered top to bottom, in the order listed.
_FLAG_TABLE = (
    # Dataset choices: 'cora', 'citeseer', 'pubmed'.
    (flags.DEFINE_string, 'dataset', 'cora', 'Dataset string.'),
    # Model choices handled by the selection code in this script:
    # 'gcn', 'gcn_cheby', 'dense', 'tagcn'.
    (flags.DEFINE_string, 'model', 'tagcn', 'Model string.'),
    (flags.DEFINE_float, 'learning_rate', 0.01, 'Initial learning rate.'),
    (flags.DEFINE_integer, 'epochs', 200, 'Number of epochs to train.'),
    (flags.DEFINE_integer, 'hidden1', 16, 'Number of units in hidden layer 1.'),
    (flags.DEFINE_float, 'dropout', 0.5, 'Dropout rate (1 - keep probability).'),
    (flags.DEFINE_float, 'weight_decay', 5e-4, 'Weight for L2 loss on embedding matrix.'),
    (flags.DEFINE_integer, 'early_stopping', 45, 'Tolerance for early stopping (# of epochs).'),
    (flags.DEFINE_integer, 'max_degree', 3, 'Maximum Chebyshev polynomial degree.'),
)
for _define_flag, _flag_name, _flag_default, _flag_help in _FLAG_TABLE:
    _define_flag(_flag_name, _flag_default, _flag_help)
# Load data: graph adjacency, node features, and the label arrays / masks for
# the train, validation and test splits of the dataset named by --dataset.
(adj, features,
 y_train, y_val, y_test,
 train_mask, val_mask, test_mask) = load_data(FLAGS.dataset)

# Some preprocessing
# preprocess_features(features) is currently disabled; the loaded features are
# only converted with .todense().
features = features.todense()

# Pick the graph "support" fed to the network and the model class to build.
if FLAGS.model in ('gcn', 'dense'):
    # Both variants compute the same single preprocessed adjacency; per the
    # original comment the 'dense' (MLP) model does not use it.
    support = [preprocess_adj(adj)]
    num_supports = 1
    model_func = GCN if FLAGS.model == 'gcn' else MLP
elif FLAGS.model == 'gcn_cheby':
    support = chebyshev_polynomials(adj, FLAGS.max_degree)
    num_supports = FLAGS.max_degree + 1
    model_func = GCN
elif FLAGS.model == 'tagcn':
    # NOTE(review): the path-weight file name is hard-coded to the cora
    # dataset and does not follow FLAGS.dataset -- confirm whether other
    # datasets are meant to be supported by this branch.
    support = np.load("cora_path_weights_norm.dat").astype('float32')
    num_supports = 1
    model_func = TAGCN
else:
    raise ValueError('Invalid argument for model: ' + str(FLAGS.model))
# Define placeholders
# All graph inputs are dense tensors here. (An earlier variant of this dict
# used tf.sparse_placeholder for 'support' and 'features' and a dropout
# default of 0.)
placeholders = {
    # NOTE(review): the fixed (N, N, 2) shape presumably matches the array
    # loaded for the 'tagcn' model; the list-valued supports built for the
    # other model choices look incompatible with it -- confirm.
    'support': tf.placeholder(
        tf.float32, shape=[features.shape[0], features.shape[0], 2]),
    'features': tf.placeholder(tf.float32, shape=features.shape),
    'labels': tf.placeholder(tf.float32, shape=(None, y_train.shape[1])),
    'labels_mask': tf.placeholder(tf.int32),
    # Default of 0 means "no dropout" whenever no value is fed. The training
    # loop feeds FLAGS.dropout explicitly, while evaluate() feeds nothing for
    # this key (unless construct_feed_dict does -- not visible in this file),
    # so the default is what validation and test passes run with. The
    # previous default of 0.5 left dropout switched on during evaluation.
    'dropout': tf.placeholder_with_default(0., shape=()),
    'num_features_nonzero': tf.placeholder(tf.int32),  # helper variable for sparse dropout
}
# Create model: instantiate whichever model class was selected above, wiring
# it to the placeholders; the input width is the number of feature columns.
model = model_func(
    placeholders,
    logging=True,
    input_dim=features.shape[1],
)

# Initialize session (module-level; used by evaluate() and the training loop).
# NOTE(review): the session is never explicitly closed in this script.
sess = tf.Session()
# writer = tf.summary.FileWriter('tmp/', sess.graph)
# Define model evaluation function
def evaluate(features, support, labels, mask, placeholders):
    """Run one forward pass and report loss, accuracy and wall-clock time.

    Builds a feed dict with ``construct_feed_dict`` from the given inputs and
    evaluates ``model.loss`` and ``model.accuracy`` in the module-level
    session ``sess``. No dropout value is fed here, so whatever default the
    'dropout' placeholder declares applies (unless ``construct_feed_dict``
    feeds one itself).

    Returns:
        A ``(loss, accuracy, elapsed)`` tuple, where ``elapsed`` is the time
        in seconds spent building the feed dict and running the session.
    """
    started = time.time()
    feed = construct_feed_dict(features, support, labels, mask, placeholders)
    loss, accuracy = sess.run([model.loss, model.accuracy], feed_dict=feed)
    elapsed = time.time() - started
    return loss, accuracy, elapsed
# Init variables
sess.run(tf.global_variables_initializer())

# Validation loss of every epoch so far; consulted by the early-stopping rule.
cost_val = []

# Train model
for epoch in range(FLAGS.epochs):
    epoch_start = time.time()

    # Construct feed dictionary, overriding the dropout placeholder with the
    # rate configured on the command line.
    train_feed = construct_feed_dict(features, support, y_train, train_mask, placeholders)
    train_feed.update({placeholders['dropout']: FLAGS.dropout})

    # Training step: run the optimizer op and fetch training loss / accuracy.
    _, train_loss, train_acc = sess.run(
        [model.opt_op, model.loss, model.accuracy], feed_dict=train_feed)

    # Validation
    val_loss, val_acc, _ = evaluate(features, support, y_val, val_mask, placeholders)
    cost_val.append(val_loss)

    # Print results
    print("Epoch:", '%04d' % (epoch + 1),
          "train_loss=", "{:.5f}".format(train_loss),
          "train_acc=", "{:.5f}".format(train_acc),
          "val_loss=", "{:.5f}".format(val_loss),
          "val_acc=", "{:.5f}".format(val_acc),
          "time=", "{:.5f}".format(time.time() - epoch_start))

    # Early stopping is only considered once the epoch index exceeds the
    # configured tolerance.
    if epoch <= FLAGS.early_stopping:
        continue

    # Stop when the latest validation loss is above the mean of the
    # `early_stopping` validation losses that preceded it.
    previous_losses = cost_val[-(FLAGS.early_stopping + 1):-1]
    if cost_val[-1] > np.mean(previous_losses):
        print("Early stopping...")
        break

print("Optimization Finished!")
# Testing: evaluate once on the test labels / mask and print the summary line.
test_cost, test_acc, test_duration = evaluate(
    features, support, y_test, test_mask, placeholders)

# Assemble the printed fields as alternating label / 5-decimal value pairs.
test_report = ["Test set results:"]
for _label, _value in (("cost=", test_cost),
                       ("accuracy=", test_acc),
                       ("time=", test_duration)):
    test_report.extend([_label, "{:.5f}".format(_value)])
print(*test_report)
# writer.close()
# model.save(sess=sess)