-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
116 lines (94 loc) · 4.63 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
###############################################################
# @Author : Peizhao Li
# @Contact : [email protected]
# reference github link : https://github.com/brandeis-machine-learning/FairAdj
#
# Added functionality:
# 1. dataloader
# 2. fairdrop
# 3. greedy
# 4. evaluation
#############################################################
import torch
import numpy as np
import random
import networkx as nx
from args import parse_args
from dataloader import load_dataset
from model.fairdrop import *
# from model.gae import GCNModelVAE, loss_function, preprocess_graph
from model.fairU import FairU, preprocess_graph
from model.greedy import greedy_pp
from evals import get_scores, result_print
import torch.nn.functional as F
def _label_edges(edges, label):
    """Append a constant label column to an edge array.

    Args:
        edges: 2-D array with one edge per row.
        label: Value written into the new last column of every row.

    Returns:
        A new integer array: ``edges`` with the label column appended.
    """
    label_column = np.full((edges.shape[0], 1), label)
    return np.concatenate([edges, label_column], axis=1).astype(int)


def main(args):
    """Train a ``FairU`` model on the configured dataset and print its scores.

    The function loads the dataset, trains for ``args.n_epochs`` epochs over
    every train fold, embeds the nodes with the trained model, scores node
    pairs by the inner product of their embeddings and prints the evaluation
    of the test edges (optionally after greedy post-processing).

    Args:
        args: Parsed options. This function reads ``dataset``, ``device``,
            ``n_epochs``, ``fairdrop``, ``fairdrop_term``, ``greedy`` and
            ``greedy_change_pct``; the whole object is also handed to
            ``FairU``.
    """
    (G, adj, features, sensitive,
     train_true_edges_split, train_false_edges_split,
     test_true_edges, test_false_edges) = load_dataset(args.dataset, 'data')
    n_nodes, feat_dim = features.shape
    features = torch.FloatTensor(features).to(args.device)

    model = FairU(feat_dim, args, sensitive)
    model = model.to(args.device)

    print('start train' + '-'*50)
    model.train()
    for epoch in range(args.n_epochs):
        for fold in range(len(train_true_edges_split)):
            train_true_edges = train_true_edges_split[fold]
            train_false_edges = train_false_edges_split[fold]

            # Training graph for this fold: G without the fold's true edges.
            # The edges must be removed before they get their label column.
            train_G = G.copy()
            train_G.remove_edges_from(train_true_edges)
            train_adj = nx.adjacency_matrix(train_G, nodelist=sorted(G.nodes()))

            # Rows of train_edges end with the label: 1 for true edges,
            # 0 for false edges.
            train_edges = np.concatenate([
                _label_edges(train_true_edges, 1),
                _label_edges(train_false_edges, 0),
            ])

            fairdrop = fairdropper(train_G, train_adj.copy(), sensitive, args.n_epochs, args.device)
            fairdrop.build_drop_map()

            # NOTE(review): the `else` is paired with `if args.fairdrop`, so
            # with fairdrop enabled the values from the last drop are reused
            # on epochs that are not a multiple of fairdrop_term -- confirm
            # this pairing against the intended layout.
            if args.fairdrop:
                if epoch % args.fairdrop_term == 0:
                    adj_norm, adj_label, pos_weight, norm = fairdrop.drop_fairly(epoch)
            else:
                # NOTE(review): this rebinds `adj`, so the evaluation below
                # preprocesses this value instead of the `adj` returned by
                # load_dataset -- confirm that is intended.
                adj, adj_norm, adj_label = preprocess_graph(train_adj)
                adj_norm = adj_norm.to(args.device)
                adj_label = adj_label.to(args.device)

                n_cells = adj.shape[0] * adj.shape[0]
                adj_total = adj.sum()
                pos_weight = float(n_cells - adj_total) / adj_total
                pos_weight = torch.Tensor([pos_weight]).to(args.device)
                norm = n_cells / float((n_cells - adj_total) * 2)

            model.optimizer.zero_grad()
            adj_preds, mu, logvar, link_preds, adv_preds = model(features, adj_norm, train_edges)
            for_reconloss = [mu, logvar, n_nodes, norm, pos_weight]
            loss = model.loss_function(adj_preds, adj_label, for_reconloss, link_preds, train_edges, adv_preds)
            loss.backward()
            cur_loss = loss.item()
            model.optimizer.step()

        # NOTE(review): reports the loss of the last fold once every 50
        # epochs; placement at epoch level (not per fold) is to be confirmed.
        if (epoch+1) % 50 == 0:
            print(f"Epoch: [{epoch+1:d} / {args.n_epochs}]; Loss: {cur_loss:.3f};")

    model.eval()
    with torch.no_grad():
        _, adj_norm, _ = preprocess_graph(adj)
        adj_norm = adj_norm.to(args.device)
        z = model(features, adj_norm, None, train=False)
    hidden_emb = z.data.cpu().numpy()

    # Pairwise scores are inner products of the node embeddings.
    # np.longdouble is used instead of np.float128 because the latter is not
    # defined on every platform; where float128 exists it is the same type.
    preds = np.array(np.dot(hidden_emb, hidden_emb.T), dtype=np.longdouble)

    if args.greedy:
        # The median score is only needed as the threshold for greedy_pp.
        thresh = np.median(preds)
        test_true_edges = _label_edges(test_true_edges, 1)
        test_false_edges = _label_edges(test_false_edges, 0)
        test_edges = np.concatenate([test_true_edges, test_false_edges])
        pred_temp = [preds[i, j] for i, j, _ in test_edges]
        new_preds = greedy_pp(G, sensitive, test_edges, pred_temp, thresh, args.greedy_change_pct)
        scores = get_scores(test_true_edges, test_false_edges, new_preds, G, sensitive)
    else:
        scores = get_scores(test_true_edges, test_false_edges, preds, G, sensitive)
    result_print(scores)
def fix_seed(seed):
    """Seed the random number generators used by this script.

    Calls, in this order, the seeding functions of ``random``,
    ``numpy.random``, ``torch`` and ``torch.cuda`` (current device, then all
    devices) with the same value.

    Args:
        seed: Value passed unchanged to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
# Script entry point: parse the options, seed the RNGs, then run main().
if __name__ == "__main__":
    cli_args = parse_args()
    print(cli_args)
    # Replace the parsed device value with a torch.device built from it.
    cli_args.device = torch.device(cli_args.device)
    fix_seed(cli_args.seed)
    main(cli_args)