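# sign_trainer.py
# Trains a SIGN (Scalable Inception Graph Neural Networks) model on the Flickr
# dataset with GammaGL and TensorLayerX. The T.SIGN(K) transform precomputes
# K hops of feature propagation up front, so training only needs a simple
# feed-forward model over the precomputed features.
#
# Example invocation (the dataset path and flag values are illustrative;
# adjust them to your setup):
#   python sign_trainer.py --dataset_path ./data/Flickr --K 2 --hidden_dim 1024 --lr 0.01 --gpu 0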
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# os.environ['TL_BACKEND'] = 'torch'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 0:Output all; 1:Filter out INFO; 2:Filter out INFO and WARNING; 3:Filter out INFO, WARNING, and ERROR
import argparse
import tensorlayerx as tlx
from tensorlayerx.model import TrainOneStep, WithLoss
from gammagl.datasets.flickr import Flickr
from gammagl.models.sign import SignModel
import gammagl.transforms as T
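

# Semi-supervised loss wrapper: the backbone is run on the full node set, but
# the loss is computed only on the nodes selected by the training mask.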
class SemiSpvzLoss(WithLoss):
    def __init__(self, net, loss_fn):
        super(SemiSpvzLoss, self).__init__(backbone=net, loss_fn=loss_fn)

    def forward(self, data, label):
        logits = self._backbone(data['xs'])
        # CrossEntropyLoss is equivalent to nll_loss(log_softmax(h), target)
        loss = self._loss_fn(logits[data['train_mask']], label[data['train_mask']])
        return loss
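

# Run inference with the model and report accuracy over the nodes selected by
# `mask`; the metric state is reset after every call.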
def evaluate(net, xs, label, mask, metrics):
    net.set_eval()
    logits = net(xs)
    metrics.update(logits[mask], label[mask])
    acc = metrics.result()
    metrics.reset()
    return acc
def main(args):
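    # T.SIGN(K) precomputes K rounds of neighbourhood feature propagation and
    # stores the results on the graph as attributes x1, ..., xK alongside the
    # raw features x.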
    transform = T.Compose([T.NormalizeFeatures(), T.SIGN(args.K)])
    dataset = Flickr(args.dataset_path, transform=transform)
    graph = dataset[0]
    graph = graph.tensor()  # the transform may convert some attributes to numpy.ndarray
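
    # Model inputs: the raw node features plus the K precomputed propagations,
    # i.e. [x, x1, ..., xK].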
    xs = [graph.x]
    xs += [graph[f'x{i}'] for i in range(1, args.K + 1)]

    net = SignModel(K=args.K, in_feat=dataset.num_node_features,
                    hid_feat=args.hidden_dim, num_classes=dataset.num_classes,
                    drop=1 - args.keep_rate)
    optimizer = tlx.optimizers.Adam(args.lr, weight_decay=args.l2_coef)
    metrics = tlx.metrics.Accuracy()
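
    # TrainOneStep wraps the loss computation, gradient calculation and
    # optimizer update into a single callable (one call = one training step).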
    train_weights = net.trainable_weights
    loss_func = SemiSpvzLoss(net, tlx.losses.softmax_cross_entropy_with_logits)
    train_one_step = TrainOneStep(loss_func, optimizer, train_weights)

    best_val_acc = 0
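
    # Full-batch training: each epoch runs one optimization step on the whole
    # graph, evaluates on the validation split, and checkpoints the weights
    # whenever validation accuracy improves. The best checkpoint is reloaded
    # for the final test evaluation.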
    data = {"xs": xs,
            "train_mask": graph.train_mask}
    for epoch in range(args.n_epoch):
        net.set_train()
        train_loss = train_one_step(data, graph.y)
        val_acc = evaluate(net, xs, graph.y, graph.val_mask, metrics)
        print("Epoch [{:0>3d}] ".format(epoch + 1)
              + " train loss: {:.4f}".format(train_loss.item())
              + " val acc: {:.4f}".format(val_acc))

        # save best model on evaluation set
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            net.save_weights(args.best_model_path + "SIGN.npz")

    net.load_weights(args.best_model_path + "SIGN.npz")
    test_acc = evaluate(net, xs, graph.y, graph.test_mask, metrics)
    print("Test acc: {:.4f}".format(test_acc))
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--lr", type=float, default=0.01, help="learning rate")
    parser.add_argument("--dataset_path", type=str, default=r'', help="path to save dataset")
    parser.add_argument('--dataset', type=str, default='Flickr', help='only Flickr is supported')
    parser.add_argument("--hidden_dim", type=int, default=1024, help="dimension of hidden layers")
    parser.add_argument("--best_model_path", type=str, default=r'./', help="path to save best model")
    parser.add_argument("--n_epoch", type=int, default=200, help="number of epochs")
    parser.add_argument("--l2_coef", type=float, default=0., help="l2 loss coefficient")
    parser.add_argument("--K", type=int, default=2)
    parser.add_argument("--keep_rate", type=float, default=0.5, help="keep_rate = 1 - drop_rate")
    parser.add_argument("--gpu", type=int, default=0)
    args = parser.parse_args()

    if args.gpu >= 0:
        tlx.set_device("GPU", args.gpu)
    else:
        tlx.set_device("CPU")

    main(args)