-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
147 lines (115 loc) · 5.86 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
from param_parser import parse_args
from utils import *
import os
import torch
from dataset import load_dataset, split_edge, EdgeDataset
from torch_geometric.loader import DataLoader
from model import *
import torch.optim as optim
import torch.nn as nn
from torch_geometric.nn import Node2Vec
from learn import *
import math
from sentence_transformers import SentenceTransformer
import time
import numpy as np
from sklearn.manifold import TSNE
import warnings
# Silence every UserWarning for the remainder of the process; warnings of any
# other category keep whatever handling the active filters give them.
warnings.simplefilter("ignore", category=UserWarning)
def run(data, loaders, model, classifier, optimizer, loss_fn, train_edge0, train_edge, args):
    """Train a model/classifier pair, keep the best validation checkpoint, and test it.

    Every 5 epochs (epoch 0, 5, 10, ...) the pair is evaluated on ``loaders['val']``.
    Whenever validation macro-F1 improves, both state dicts are written to
    ``./model/<args.dataset>/best_model.pt`` and ``best_classifier.pt``. After
    training, the checkpoint saved by this call (if any) is reloaded and the pair
    is evaluated on ``loaders['test']``.

    Args:
        data: graph object, forwarded unchanged to ``train``/``eval``.
        loaders: dict holding at least the ``'train'``, ``'val'`` and ``'test'`` loaders.
        model: encoder module being trained.
        classifier: classifier module trained jointly with ``model``.
        optimizer: optimizer forwarded to ``train``.
        loss_fn: loss forwarded to ``train``.
        train_edge0: raw training edges, forwarded to ``train``.
        train_edge: processed training edges, forwarded to ``train`` and ``eval``.
        args: namespace; ``epochs`` and ``dataset`` are read here, and the whole
            namespace is forwarded to ``train``/``eval``.

    Returns:
        Tuple ``(test_acc, bacc, f1, f1_macro, f1_micro)`` computed on the test split.
    """
    ckpt_dir = f'./model/{args.dataset}'
    model_path = f'{ckpt_dir}/best_model.pt'
    classifier_path = f'{ckpt_dir}/best_classifier.pt'
    # torch.save does not create missing parent directories.
    os.makedirs(ckpt_dir, exist_ok=True)

    best_val_f1_macro = -math.inf
    saved_checkpoint = False
    for epoch in range(args.epochs):
        train(data, model, classifier, loaders['train'], optimizer, loss_fn, train_edge0, train_edge, args)
        if epoch % 5 != 0:
            continue
        # Model selection uses validation macro-F1 only; the other metrics are discarded.
        _, _, _, f1_macro, _, _, _, _ = eval(data, model, classifier, loaders['val'], train_edge, args)
        # A NaN score never compares greater, so it can never become the best one.
        if f1_macro > best_val_f1_macro:
            best_val_f1_macro = f1_macro
            torch.save(model.state_dict(), model_path)
            torch.save(classifier.state_dict(), classifier_path)
            saved_checkpoint = True

    # Reload only a checkpoint written by this call. Otherwise a file left over from
    # an earlier run would be loaded, or loading would fail because no file exists.
    if saved_checkpoint:
        model.load_state_dict(torch.load(model_path))
        classifier.load_state_dict(torch.load(classifier_path))
    test_acc, bacc, f1, f1_macro, f1_micro, _, _, _ = eval(data, model, classifier, loaders['test'], train_edge, args)
    return test_acc, bacc, f1, f1_macro, f1_micro
if __name__ == '__main__':
    # Script entry point. Load the dataset, split its edges, then train and evaluate
    # the selected encoder + classifier `args.runs` times. Finally print the total
    # wall-clock time and the mean +/- std of the reported test metrics.
    start = time.time()
    args = parse_args()
    args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    args.path = os.getcwd()
    seed_everything(args.seed)

    # Sentence encoder handed to load_dataset (presumably used to embed text features; TODO confirm).
    encoder = SentenceTransformer('all-MiniLM-L6-v2', device=args.device)
    data = load_dataset(args, encoder)
    edge_idxs = split_edge(data, args.train_ratio, args.val_ratio)

    # Training edges: train_edge0 is the raw column slice of edge_index, and
    # train_edge is its process_edge() form. run() forwards both to train/eval.
    train_edge0 = data.edge_index[:, edge_idxs['train']]
    train_edge = process_edge(train_edge0).to(args.device)
    train_edge0 = train_edge0.to(args.device)
    train_y = data.y[edge_idxs['train']]
    num_classes = int(torch.unique(data.y).shape[0])

    # Optional reweighting, stored on args (which is forwarded to train/eval):
    #   class weights    -> method 'tw', 'qw' or 'tq', or mixup == 2
    #   topology weights -> method 'tw' or 'tq', or mixup == 2
    # Each is computed at most once.
    needs_topo_reweight = args.method in ('tw', 'tq') or args.mixup == 2
    if needs_topo_reweight or args.method == 'qw':
        args.reweight = cal_reweight(train_y, num_classes).to(args.device)
    if needs_topo_reweight:
        args.topo_reweight = ge_reweight(data.edge_index[:, edge_idxs['train']], train_y.view(-1), args).to(args.device)

    # One loader per split; only the training split is shuffled.
    split_data = {key: EdgeDataset(edge_idxs[key]) for key in edge_idxs}
    loaders = {
        key: DataLoader(dataset, batch_size=args.batch_size, shuffle=(key == 'train'),
                        collate_fn=dataset.collate_fn)
        for key, dataset in split_data.items()
    }

    # Resolve the encoder class once, so an unknown name fails before any training.
    if args.model == 'GCN':
        encoder_cls = GCNEncoder
    elif args.model == 'GAT':
        encoder_cls = GATEncoder
    elif args.model == 'SAGE':
        encoder_cls = SAGEEncoder
    elif args.model == 'Cheb':
        encoder_cls = ChebEncoder
    else:
        raise NotImplementedError(f'Unknown model: {args.model}')

    # Move the graph to the device once; every run reuses it there.
    data = data.to(args.device)
    res = []
    for _ in range(args.runs):
        model = encoder_cls(data.num_nodes, args.n_embed, args.n_hidden).to(args.device)
        loss_fn = nn.CrossEntropyLoss(reduction='none')
        classifier = Classifier(args.n_hidden, data.edge_attr.shape[1], num_classes, args.dropout).to(args.device)
        optimizer = torch.optim.Adam(list(model.parameters()) + list(classifier.parameters()), lr=args.lr)
        test_acc, bacc, f1, f1_macro, f1_micro = run(data, loaders, model, classifier, optimizer,
                                                     loss_fn, train_edge0, train_edge, args)
        res.append({'test_acc': test_acc, 'bacc': bacc, 'f1': f1,
                    'f1_macro': f1_macro, 'f1_micro': f1_micro})
    end = time.time()

    # Report mean and population std (np.std default, ddof=0) across runs.
    # Only these metrics are printed; every metric above is still kept in `res`.
    metrics = ['bacc', 'f1_macro']
    summary = []
    for metric in metrics:
        values = [r[metric] for r in res]
        # Raw f-string, so the LaTeX "\pm" is not parsed as an (invalid) escape sequence.
        summary.append(rf"{metric}: {np.mean(values):.3f} $\pm$ {np.std(values):.3f}")
    print(f"Total time: {end - start:.2f} seconds," + ', '.join(summary))