-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfig.py
43 lines (25 loc) · 1.87 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import argparse
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: it calls ``bool()`` on the
    raw string, and every non-empty string (including ``"False"`` and ``"0"``)
    is truthy.  This converter interprets the usual spellings instead.

    Args:
        value: The raw option value (a string), or an actual ``bool``.

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean
            spelling; argparse reports this as a usage error.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string is kept as False: it was the only value that produced
    # False under the previous ``type=bool`` behaviour.
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected, got %r" % (value,))


def settings(argv=None):
    """Build the "GraphClassification" argument parser and parse the options.

    Args:
        argv: Optional list of argument strings to parse.  When ``None``
            (the default) argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: The parsed options, one attribute per flag
        declared below (for example ``dataset_dir``, ``epochs``,
        ``node_embedding_size``, ``Attention``, ``coordinate``).
    """
    parser = argparse.ArgumentParser("GraphClassification")

    # Data location and training schedule.
    parser.add_argument("--dataset_dir", type=str, default="data_plk/APKICFG",
                        help="Where dataset is stored.")
    parser.add_argument("--epochs", type=int, default=10, help="epochs")
    parser.add_argument("--batch_size", type=int, default=1, help="batch_size")
    parser.add_argument("--iterations", type=int, default=3,
                        help="number of iterations of dynamic routing")
    parser.add_argument("--seed", type=int, default=12345, help="Initial random seed")

    # Embedding sizes (each has a short single-dash alias).
    parser.add_argument('-node_emb_size', "--node_embedding_size", default=8, type=int,
                        help="Intended subgraph embedding size to be learnt")
    parser.add_argument('-graph_emb_size', "--graph_embedding_size", default=8, type=int,
                        help="Intended graph embedding size to be learnt")

    # Optimisation and loss hyper-parameters.
    parser.add_argument("--learning_rate", default=0.0005, type=float,
                        help="Learning rate to optimize the loss function")
    parser.add_argument("--decay_step", default=100000, type=float,
                        help="Learning rate decay step")
    parser.add_argument("--lambda_val", default=1, type=float,
                        help="Lambda factor for margin loss")
    parser.add_argument("--noise", default=0.3, type=float,
                        help="dropout applied in input data")
    parser.add_argument("--reg_scale", default=0.1, type=float,
                        help="Regualar scale (reconstruction loss)")

    # Boolean switches: parsed with _str2bool so that "--Attention False"
    # really yields False (type=bool would turn it into True).
    parser.add_argument("--Attention", default=True, type=_str2bool,
                        help="If use Attention module")
    parser.add_argument("--coordinate", default=False, type=_str2bool,
                        help="If use Location record")

    parser.add_argument("--x_fold", type=int, default=10,
                        help="build train_test_split_index for x_fold.")

    return parser.parse_args(argv)
def get_net_structure():
    """Return the layer layout of the network as a dictionary.

    A new dictionary (holding new lists) is built on every call.

    Returns:
        dict: Two entries, ``'node_emb'`` and ``'graph_emb'``, each mapping
        to a list of integers.
    """
    layout = {}
    # Number of channels in each layer of the GCN.
    layout['node_emb'] = [2]
    # Number of capsules in the graph embedding layer.
    layout['graph_emb'] = [2]
    return layout