import logging
import os
import random
import shutil
from datetime import datetime

import numpy as np
import pytz
import torch
import yaml
from dgl import function as fn
from ogb.nodeproppred import Evaluator

CPF_data = ["cora", "citeseer", "pubmed", "a-computer", "a-photo"]
OGB_data = ["ogbn-arxiv", "ogbn-products"]
NonHom_data = ["pokec", "penn94"]
BGNN_data = ["house_class", "vk_class"]


def set_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def get_training_config(config_path, model_name, dataset):
    """Load the YAML training config and merge the global defaults with the
    dataset- and model-specific settings (the latter take precedence)."""
    with open(config_path, "r") as conf:
        full_config = yaml.load(conf, Loader=yaml.FullLoader)
    dataset_specific_config = full_config["global"]
    model_specific_config = full_config[dataset][model_name]

    if model_specific_config is not None:
        specific_config = dict(dataset_specific_config, **model_specific_config)
    else:
        specific_config = dataset_specific_config

    specific_config["model_name"] = model_name
    return specific_config
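
# A minimal sketch of the config layout get_training_config expects; the keys
# below are illustrative, not the repo's actual hyperparameters:
#
#   global:
#     num_epochs: 500
#     patience: 50
#   cora:
#     SAGE:
#       learning_rate: 0.01
#       dropout_ratio: 0.5
#
# dict(global_conf, **model_conf) means model/dataset values override "global".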


def check_writable(path, overwrite=True):
    """Create `path` if missing; with `overwrite`, wipe and recreate it."""
    if not os.path.exists(path):
        os.makedirs(path)
    elif overwrite:
        shutil.rmtree(path)
        os.makedirs(path)


def check_readable(path):
    if not os.path.exists(path):
        raise ValueError(f"No such file or directory! {path}")


def timetz(*args):
    """Timestamp converter for logging.Formatter: US/Pacific local time."""
    tz = pytz.timezone("US/Pacific")
    return datetime.now(tz).timetuple()


def get_logger(filename, console_log=False, log_level=logging.INFO):
    logger = logging.getLogger(__name__)
    logger.propagate = False  # avoid duplicate logging
    logger.setLevel(log_level)

    # Clean logger first to avoid duplicated handlers
    for hdlr in logger.handlers[:]:
        logger.removeHandler(hdlr)

    file_handler = logging.FileHandler(filename)
    formatter = logging.Formatter("%(asctime)s: %(message)s", datefmt="%b%d %H-%M-%S")
    formatter.converter = timetz
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    if console_log:
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)
    return logger
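
# Usage sketch (the filename is illustrative):
#   logger = get_logger("output/train.log", console_log=True)
#   logger.info("epoch 1 done")  # written to the file and, here, to stderr too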


def idx_split(idx, ratio, seed=0):
    """Randomly split `idx` into two parts holding `ratio` and `1 - ratio`
    fractions of the elements, respectively."""
    set_seed(seed)
    n = len(idx)
    cut = int(n * ratio)
    idx_idx_shuffle = torch.randperm(n)

    idx1_idx, idx2_idx = idx_idx_shuffle[:cut], idx_idx_shuffle[cut:]
    idx1, idx2 = idx[idx1_idx], idx[idx2_idx]
    # assert((torch.cat([idx1, idx2]).sort()[0] == idx.sort()[0]).all())
    return idx1, idx2
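
# Example (values are illustrative): with idx = torch.arange(10) and
# ratio = 0.8, idx_split returns a random 8-element part and the remaining
# 2 elements. Note that it reseeds all RNGs via set_seed(seed) as a side effect.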


def graph_split(idx_train, idx_val, idx_test, rate, seed):
    """
    The original setting is transductive: the full graph is observed and
    idx_train takes up a small portion of it. Split the graph for the
    inductive setting by further dividing idx_test into
    [idx_test_tran, idx_test_ind], where
    rate = |idx_test_ind| / |idx_test| (how much of the test set to hide
    for the inductive evaluation).

    Ex. ogbn-products
        loaded:      train : val : test = 8 : 2 : 90, rate = 0.2
        after split: train : val : test_tran : test_ind = 8 : 2 : 72 : 18

    Return:
        Indices starting with 'obs_' index nodes within the observed
        subgraph, whereas indices starting with 'idx_' index nodes in the
        original graph.
    """
    idx_test_ind, idx_test_tran = idx_split(idx_test, rate, seed)

    idx_obs = torch.cat([idx_train, idx_val, idx_test_tran])
    N1, N2 = idx_train.shape[0], idx_val.shape[0]
    obs_idx_all = torch.arange(idx_obs.shape[0])
    obs_idx_train = obs_idx_all[:N1]
    obs_idx_val = obs_idx_all[N1 : N1 + N2]
    obs_idx_test = obs_idx_all[N1 + N2 :]

    return obs_idx_train, obs_idx_val, obs_idx_test, idx_obs, idx_test_ind
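
# Usage sketch (assumes a DGLGraph `g`; the subgraph call is one way to
# materialize the observed graph, not something this function does itself):
#   obs_idx_train, obs_idx_val, obs_idx_test, idx_obs, idx_test_ind = graph_split(
#       idx_train, idx_val, idx_test, rate=0.2, seed=0
#   )
#   obs_g = g.subgraph(idx_obs)  # train transductively on the observed part
#   # idx_test_ind indexes the held-out nodes in the *original* graph.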


def get_evaluator(dataset):
    if dataset in CPF_data + NonHom_data + BGNN_data:

        def evaluator(out, labels):
            pred = out.argmax(1)
            return pred.eq(labels).float().mean().item()

    elif dataset in OGB_data:
        ogb_evaluator = Evaluator(dataset)

        def evaluator(out, labels):
            pred = out.argmax(1, keepdim=True)
            input_dict = {"y_true": labels.unsqueeze(1), "y_pred": pred}
            return ogb_evaluator.eval(input_dict)["acc"]

    else:
        raise ValueError("Unknown dataset")

    return evaluator
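
# Usage sketch (tensor shapes are illustrative):
#   evaluator = get_evaluator("cora")
#   acc = evaluator(logits, labels)  # logits: [N, C], labels: [N]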


def compute_min_cut_loss(g, out):
    """Normalized min-cut value of the soft assignments S = out.exp() on g."""
    out = out.to("cpu")
    g = g.to("cpu")

    S = out.exp()
    A = g.adj().to_dense()
    D = g.in_degrees().float().diag()
    min_cut = (
        torch.matmul(torch.matmul(S.transpose(1, 0), A), S).trace()
        / torch.matmul(torch.matmul(S.transpose(1, 0), D), S).trace()
    )
    return min_cut.item()
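
# The ratio above is the normalized min-cut measure familiar from
# MinCutPool-style objectives: Tr(S^T A S) / Tr(S^T D S), with S the soft
# cluster-assignment matrix, A the dense adjacency, and D the degree matrix.
# Higher values mean clusters cut fewer edges relative to their volume.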


def feature_prop(feats, g, k):
    """
    Augment node features by propagating them within the k-hop neighborhood.
    The propagation is done SGC-style, i.e. hop by hop and symmetrically
    normalized by node degrees.
    """
    assert feats.shape[0] == g.num_nodes()

    degs = g.in_degrees().float().clamp(min=1)
    norm = torch.pow(degs, -0.5).unsqueeze(1)

    # compute (D^-1/2 A D^-1/2)^k X
    for _ in range(k):
        feats = feats * norm
        g.ndata["h"] = feats
        g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
        feats = g.ndata.pop("h")
        feats = feats * norm

    return feats
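
# Usage sketch (the toy graph and features are illustrative):
#   import dgl
#   g = dgl.graph(([0, 1, 2], [1, 2, 0]))   # small 3-node cycle
#   feats = torch.randn(g.num_nodes(), 16)
#   smoothed = feature_prop(feats, g, k=2)  # (D^-1/2 A D^-1/2)^2 X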