# 2D.py
# Current estimate: up to 4 columns, e.g. [100, 100, 100, 10]; any more and the nodes no longer fit in memory.
# In general, a larger query size covers more grid points, but since we currently sample values for multiple columns from a single random row, graph.y may not cover every grid point.
# Batching is not used yet.
# For test-2.csv, a query size of 100 seems sufficient judging by the learned fit (except for a somewhat large error at the bottom-right point).
# TODO:
#   - Evaluate the test set with Q-error.
#   - Extend to 3D, 5D, 10D.
#   - Try a log-likelihood loss, treating the CDF as a PDF.
#   - Try combining the Graph model with an AR conditional-CDF model, processing one column at a time to avoid an oversized Cartesian product.
#   - Try replacing intervalization with "a subset of points", i.e. a continuous version that fits the marginal CDF with fewer points than intervalization uses.
import argparse

# Explicit imports for names used below (np, torch, tqdm); the wildcard
# imports may also re-export them, but this keeps the script self-contained.
import numpy as np
import torch
from tqdm import tqdm

from dataset import *
from models import *
from preprocessing import *
from utils import *
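# Hedged sketch for the "log-likelihood, CDF as PDF" item in the TODO above
# (an illustration only, not part of this pipeline): on a 2-D grid, the CDF
# values at the nodes determine a per-cell probability mass by
# inclusion-exclusion,
#   p[i, j] = F[i, j] - F[i-1, j] - F[i, j-1] + F[i-1, j-1],
# which is what a log-likelihood loss over cells would consume.
def cdf_grid_to_pmf(F: np.ndarray) -> np.ndarray:
    """2-D array of CDF values on grid nodes -> probability mass per cell."""
    Fp = np.pad(F, ((1, 0), (1, 0)))  # F = 0 below / left of the grid
    return Fp[1:, 1:] - Fp[:-1, 1:] - Fp[1:, :-1] + Fp[:-1, :-1]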
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="1-input", help="model type")
parser.add_argument("--dataset", type=str, default="test-2", help="Dataset.")
parser.add_argument("--query-size", type=int, default=1000, help="query size")
parser.add_argument("--min-conditions", type=int, default=1, help="min num of query conditions")
parser.add_argument("--max-conditions", type=int, default=2, help="max num of query conditions")
# NOTE: argparse's type=bool treats any non-empty string (even "False") as True;
# BooleanOptionalAction (Python 3.9+) gives real --unique-train / --no-unique-train flags.
parser.add_argument(
    "--unique-train",
    action=argparse.BooleanOptionalAction,
    default=True,
    help="whether to make the train set unique.",
)
parser.add_argument(
    "--boundary",
    action=argparse.BooleanOptionalAction,
    default=False,
    help="whether to add boundary points to the train set.",
)
parser.add_argument(
"--channels", type=str, default="2,16,1", help="Comma-separated list of channels."
)
parser.add_argument("--num_layers", type=int, default=3, help="Number of hidden layers.")
parser.add_argument("--epochs", type=int, default=3000, help="Number of train epochs.")
parser.add_argument("--bs", type=int, default=1000, help="Batch size.")
parser.add_argument("--loss", type=str, default="MSE", help="Loss.")
parser.add_argument("--opt", type=str, default="adam", help="Optimizer.")
parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
parser.add_argument(
    "--plot_labels",
    action=argparse.BooleanOptionalAction,
    default=False,
    help="whether to add labels (selectivity) to the comparison plot.",
)
try:
    args = parser.parse_args()
except SystemExit:
    # parse_args() raises SystemExit on unrecognized arguments (e.g. when run
    # from a notebook); fall back to parsing only the known ones.
    args, unknown = parser.parse_known_args()
args.channels = [int(x) for x in args.channels.split(",")]
ModelName = "2D_GNN"
FilePath = (
f"{args.model}_{args.dataset}_{args.query_size}_({args.min_conditions}_{args.max_conditions})"
)
resultsPath = f"results/{ModelName}/{FilePath}"
make_directory(resultsPath)
print("\nBegin Loading Data ...")
table, original_table_columns, sorted_table_columns, max_decimal_places = load_and_process_dataset(
args.dataset, resultsPath
)
table_size = table.shape
print(f"{args.dataset}.csv")
print(f"Table shape: {table_size}")
print("Done.\n")
print("Begin Generating Queries ...")
rng = np.random.RandomState(42)
query_set = [generate_random_query(table, args, rng) for _ in tqdm(range(args.query_size))]
print("Done.\n")
print("Begin Intervalization ...")
column_intervals = column_intervalization(query_set, table_size, args)
column_interval_number = count_unique_vals_num(column_intervals)
print(f"{column_interval_number=}")
print("Done.\n")
print("Begin Building Graph and Model ...")
graph = setup_graph(args, query_set, column_intervals, column_interval_number, table_size)
# pos = [
# np.array(np.unravel_index(i, column_interval_number)).reshape(1, -1) + 1
# for i in range(graph.x.shape[0])
# ]
# pos = np.concatenate(pos, axis=0)
# graph.pos = torch.from_numpy(pos).float()
# Visualize_initial_Graph_2D(graph, column_interval_number)
model = BaseModel(args, resultsPath, graph, device)
graph = graph.to(device)
print("Done.\n")
print("Begin Model Training ...\n")
model.fit()
print("Done.\n")
print("Begin Model Prediction ...")
# model.load()
out = model.predict(graph).squeeze(dim=-1).detach().cpu()
print(f"\nGround Truth:\n{graph.y[graph.train_mask]}")
print(f"\nModel Output:\n{out[graph.train_mask]}")
err = (out[graph.train_mask] - graph.y[graph.train_mask]).pow(2).mean()
print(f"\nFinal MSE: {err}")
print("\nDone.\n")
Visualize_compare_Graph_2D(
    graph,
    out,
    args,
    resultsPath,
    figsize=(30, 15),
    to_undirected=True,
    with_labels=args.plot_labels,  # honor the CLI flag instead of hardcoding False
)
# print("Begin Generating Data ...")
# Table_Generated = m.generate_table_by_row(values, batch_size=10000)
# print("Done.\n")
# print("Summary of Q-error:")
# print(args)
# df = calculate_Q_error(Table_Generated, query_set)
# df.to_csv(f"{resultsPath}/Q_error.csv", index=True, header=False)
# print(df)
# print(f"\n Original table shape : {table_size}")
# print(f"Generated table shape : {Table_Generated.shape}")
# print("Begin Recovering Data ...")
# recovered_Table_Generated = recover_table_as_original(
# Table_Generated, original_table_columns, sorted_table_columns, max_decimal_places
# )
# recovered_Table_Generated.to_csv(f"{resultsPath}/generated_table.csv", index=False, header=False)
# print("Done.\n")