-
Notifications
You must be signed in to change notification settings - Fork 0
/
debug_control.py
65 lines (48 loc) · 1.76 KB
/
debug_control.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import argparse
import random
import numpy as np
from data import ToyDataset, RankingTransform, ControlTransform
from models import GCN, GraphMLP
import torch
from torch.nn.functional import cross_entropy
from torchmetrics import Accuracy
from torch_geometric.loader import DataLoader
def main():
    """Build the toy dataset and a GCN from CLI options, then run forward passes.

    Command-line options:
        --linear / --time_inv: boolean switches forwarded to ``GCN``.
        --control_type: forwarded to ``GCN``; any value other than ``"null"``
            also enables a ``ControlTransform`` on the dataset.
        --control_edges / --control_metric / --control_k: positional
            arguments of ``ControlTransform``.
        --hidden_dim / --conv_depth / --dropout: ``GCN`` hyper-parameters.

    The Python, NumPy and PyTorch RNGs are each seeded with 0 before any
    dataset, loader or model object is created. Nothing is returned.
    """
    cli = argparse.ArgumentParser()

    # Boolean switches (default False, set to True when present).
    for switch in ("--linear", "--time_inv"):
        cli.add_argument(switch, action="store_true")

    # Valued options as (flag, default, type); registration order matches
    # the order in which they appear in ``--help``.
    valued_options = (
        ("--control_type", "null", str),
        ("--control_edges", "adj", str),
        ("--control_metric", "degree", str),
        ("--control_k", 1, int),
        ("--hidden_dim", 8, int),
        ("--conv_depth", 2, int),
        ("--dropout", 0.0, float),
    )
    for flag, fallback, caster in valued_options:
        cli.add_argument(flag, default=fallback, type=caster)

    opts = cli.parse_args()

    # Fix every RNG source before constructing anything.
    random.seed(0)
    np.random.seed(0)
    torch.random.manual_seed(0)
    # Wide rows so printed tensors are not wrapped.
    torch.set_printoptions(linewidth=320)

    # "null" means no control transform is applied to the dataset.
    control = (
        None
        if opts.control_type == "null"
        else ControlTransform(
            opts.control_edges, opts.control_metric, opts.control_k
        )
    )
    ranking = RankingTransform()
    graphs = ToyDataset(control, ranking)
    loader = DataLoader(graphs, batch_size=2, shuffle=True)

    # Input width is read from the second dimension of the first sample's x.
    feature_dim = graphs[0].x.shape[1]
    net = GCN(
        input_dim=feature_dim,
        output_dim=2,
        hidden_dim=opts.hidden_dim,
        conv_depth=opts.conv_depth,
        dropout_rate=opts.dropout,
        linear=opts.linear,
        time_inv=opts.time_inv,
        control_type=opts.control_type,
    )

    # One pass over the loader; the result is bound but not used further here.
    for graph_batch in loader:
        prediction = net(graph_batch)
# Run the driver only when this file is executed as a script, so importing
# the module does not parse arguments or build the model.
if __name__ == "__main__":
    main()