Skip to content

Commit

Permalink
somewhat stable version of the idea
Browse files Browse the repository at this point in the history
  • Loading branch information
andyhuang-kumo committed Aug 28, 2024
1 parent a181f45 commit b86641f
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 73 deletions.
88 changes: 34 additions & 54 deletions benchmark/relbench_link_prediction_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from torch_geometric.utils import index_to_mask
from torch_geometric.utils.map import map_index


PRETRAIN_EPOCH = 1

TRAIN_CONFIG_KEYS = ["batch_size", "gamma_rate", "base_lr"]
LINK_PREDICTION_METRIC = "link_prediction_map"
Expand Down Expand Up @@ -72,8 +72,7 @@
parser.add_argument("--result_path", type=str, default="result")
args = parser.parse_args()

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cpu"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
torch.set_num_threads(1)
seed_everything(args.seed)
Expand Down Expand Up @@ -153,23 +152,24 @@
train_search_space = {
"batch_size": [128, 256],
"base_lr": [0.0005, 0.01],
"gamma_rate": [0.9, 1.0],
"gamma_rate": [0.8, 1.0],
}
model_cls = RHSTransformer
elif args.model in ["rerank_transformer"]:
model_search_space = {
"encoder_channels": [64, 128, 256],
"encoder_layers": [2, 4, 8],
"channels": [64],
"embedding_dim": [64],
"channels": [64,128,256],
"norm": ["layer_norm", "batch_norm"],
"dropout": [0.0],
"rank_topk": [200],
"dropout": [0.0,0.1,0.2],
"t_encoding_type": ["fuse", "absolute"],
"rank_topk": [100,150,200],
"num_tr_layers": [1,2,3],
}
train_search_space = {
"batch_size": [128, 256, 512],
"batch_size": [256, 512],
"base_lr": [0.0005, 0.01],
"gamma_rate": [0.9, 1.0],
"gamma_rate": [0.8,1.0],
}
model_cls = ReRankTransformer

Expand Down Expand Up @@ -220,23 +220,22 @@ def train(
loss = sparse_cross_entropy(logits, edge_label_index)
numel = len(batch[task.dst_entity_table].batch)
elif args.model in ["rerank_transformer"]:
gnn_logits, tr_logits, topk_idx = model(batch, task.src_entity_table, task.dst_entity_table, task.dst_entity_col)
edge_label_index = torch.stack([src_batch, dst_index], dim=0)
loss = sparse_cross_entropy(gnn_logits, edge_label_index)
numel = len(batch[task.dst_entity_table].batch)

#* approach with map_index
label_index, mask = map_index(topk_idx.view(-1), dst_index)
true_label = torch.zeros(topk_idx.shape).to(tr_logits.device)
true_label[mask.view(true_label.shape)] = 1.0

# #* empty label rows
nonzero_mask = torch.any(true_label, dim=1)
tr_logits = tr_logits[nonzero_mask]
true_label = true_label[nonzero_mask]

#* the loss of the transformer should be scaled down? ((topk_idx.shape[1] / gnn_logits.shape[1]))
loss += F.binary_cross_entropy_with_logits(tr_logits, true_label.float())
gnn_logits, tr_logits, topk_idx = model(batch, task.src_entity_table, task.dst_entity_table, task.dst_entity_col)
if (epoch <= PRETRAIN_EPOCH):
edge_label_index = torch.stack([src_batch, dst_index], dim=0)
loss = sparse_cross_entropy(gnn_logits, edge_label_index)
else:
#* approach with map_index
label_index, mask = map_index(topk_idx.view(-1), dst_index)
true_label = torch.zeros(topk_idx.shape).to(tr_logits.device)
true_label[mask.view(true_label.shape)] = 1.0

# # #* empty label rows
# nonzero_mask = torch.any(true_label, dim=1)
# tr_logits = tr_logits[nonzero_mask]
# true_label = true_label[nonzero_mask]
loss = F.binary_cross_entropy_with_logits(tr_logits, true_label.float())

loss.backward()

Expand All @@ -259,7 +258,7 @@ def train(


@torch.no_grad()
def test(model: torch.nn.Module, loader: NeighborLoader, stage: str) -> float:
def test(model: torch.nn.Module, loader: NeighborLoader, stage: str, epoch:int,) -> float:
model.eval()

pred_list: List[Tensor] = []
Expand Down Expand Up @@ -290,32 +289,13 @@ def test(model: torch.nn.Module, loader: NeighborLoader, stage: str) -> float:
gnn_logits, tr_logits, topk_idx = model(batch, task.src_entity_table,
task.dst_entity_table,
task.dst_entity_col)

_, pred_idx = torch.topk(torch.sigmoid(tr_logits).detach(), k=task.eval_k, dim=1)
pred_mini = topk_idx[torch.arange(topk_idx.size(0)).unsqueeze(1), pred_idx]

# _, pred_mini = torch.topk(torch.sigmoid(gnn_logits.detach()), k=task.eval_k, dim=1)
# gnn_out = pred_mini[0]
# sort_out, _ = torch.sort(gnn_out)
# gnn_out = sort_out

# tr_out = pred_mini[0]
# sort_out, _ = torch.sort(tr_out)
# tr_out = sort_out

# assert torch.equal(gnn_out, tr_out)


#! to remove
# scores = torch.zeros(batch_size, task.num_dst_nodes,
# device=tr_logits.device)
# tr_logits = tr_logits.detach()
# scores.scatter_(1, topk_idx, torch.sigmoid(tr_logits.detach()))

# for i in range(scores.shape[0]):
# scores[i][topk_idx[i]] = torch.sigmoid(tr_logits[i])
# scores.scatter_(1, topk_idx, torch.sigmoid(tr_logits.detach()))
# scores[topk_index] = torch.sigmoid(tr_logits.detach().flatten())
if (epoch <= PRETRAIN_EPOCH):
scores = torch.sigmoid(gnn_logits.detach())
_, pred_mini = torch.topk(scores, k=task.eval_k, dim=1)
else:
_, pred_idx = torch.topk(torch.sigmoid(tr_logits.detach()), k=task.eval_k, dim=1)
pred_mini = topk_idx[torch.arange(topk_idx.size(0)).unsqueeze(1), pred_idx]

else:
raise ValueError(f"Unsupported model type: {args.model}.")
Expand Down Expand Up @@ -389,11 +369,11 @@ def train_and_eval_with_cfg(
train_loss = train(model, optimizer, loader_dict["train"],
train_sparse_tensor, epoch)
optimizer.zero_grad()
val_metric = test(model, loader_dict["val"], "val")
val_metric = test(model, loader_dict["val"], "val", epoch)

if val_metric > best_val_metric:
best_val_metric = val_metric
best_test_metric = test(model, loader_dict["test"], "test")
best_test_metric = test(model, loader_dict["test"], "test", epoch)

lr_scheduler.step()
print(f"Train Loss: {train_loss:.4f}, Val: {val_metric:.4f}")
Expand Down
40 changes: 21 additions & 19 deletions hybridgnn/nn/models/rerank_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,26 +28,27 @@ class ReRankTransformer(torch.nn.Module):
col_stats_dict (Dict[str, Dict[str, Dict[StatType, Any]]]): column stats
num_nodes (int): number of nodes,
num_layers (int): number of mp layers,
channels (int): input dimension,
embedding_dim (int): embedding dimension size,
channels (int): input dimension and embedding dimension,
aggr (str): aggregation type,
norm (str): normalization type,
dropout (float): dropout rate for the transformer,
heads (int): number of attention heads,
rank_topk (int): how many top results of gnn would be reranked,"""
rank_topk (int): how many top results of gnn would be reranked,
num_tr_layers (int): number of transformer layers,"""
def __init__(
self,
data: HeteroData,
col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],
num_nodes: int,
num_layers: int,
channels: int,
embedding_dim: int,
aggr: str = 'sum',
norm: str = 'layer_norm',
dropout: float = 0.2,
heads: int = 1,
rank_topk: int = 100,
t_encoding_type: str = "absolute",
num_tr_layers: int = 1,
torch_frame_model_cls: Type[torch.nn.Module] = ResNet,
torch_frame_model_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
Expand All @@ -70,6 +71,7 @@ def __init__(
if "time" in data[node_type]
],
channels=channels,
encoding_type=t_encoding_type,
)
self.gnn = HeteroGraphSAGE(
node_types=data.node_types,
Expand All @@ -84,23 +86,24 @@ def __init__(
norm=norm,
num_layers=1,
)
self.lhs_projector = torch.nn.Linear(channels, embedding_dim)
self.lhs_projector = torch.nn.Linear(channels, channels)
self.id_awareness_emb = torch.nn.Embedding(1, channels)
self.rhs_embedding = torch.nn.Embedding(num_nodes, embedding_dim)
self.lin_offset_idgnn = torch.nn.Linear(embedding_dim, 1)
self.lin_offset_embgnn = torch.nn.Linear(embedding_dim, 1)
self.rhs_embedding = torch.nn.Embedding(num_nodes, channels)
self.lin_offset_idgnn = torch.nn.Linear(channels, 1)
self.lin_offset_embgnn = torch.nn.Linear(channels, 1)

self.rank_topk = rank_topk

self.tr_embed_size = channels * 2
self.tr_blocks = torch.nn.ModuleList([
MultiheadAttentionBlock(
channels=embedding_dim*2,
channels=self.tr_embed_size,
heads=heads,
layer_norm=True,
dropout=dropout,
) for _ in range(1)
) for _ in range(num_tr_layers)
])
# self.tr_lin = torch.nn.Linear(embedding_dim*2, embedding_dim)
self.tr_lin = torch.nn.Linear(embedding_dim*2,1)
self.tr_lin = torch.nn.Linear(self.tr_embed_size,1)

self.channels = channels

Expand Down Expand Up @@ -199,19 +202,18 @@ def rerank(self, gnn_logits, shallow_rhs_embed, rhs_idgnn_embed, rhs_idgnn_index
out_indices = topk_indices.clone()
# [batch_size, topk, embed_size]
seq = shallow_rhs_embed[topk_indices.flatten()].view(batch_size * self.rank_topk, embed_size)
# rhs_idgnn_index = lhs_idgnn_batch * num_rhs_nodes + rhs_idgnn_index

query_rhs_idgnn_index, mask = map_index(topk_indices.view(-1), rhs_idgnn_index)
id_gnn_seq = torch.zeros(batch_size * self.rank_topk, embed_size)
id_gnn_seq = torch.zeros(batch_size * self.rank_topk, embed_size).to(rhs_idgnn_embed.device)
id_gnn_seq[mask] = rhs_idgnn_embed[query_rhs_idgnn_index]

logit_mask = torch.zeros(batch_size * self.rank_topk, embed_size, dtype=bool)
logit_mask = torch.zeros(batch_size * self.rank_topk, embed_size, dtype=bool).to(rhs_idgnn_embed.device)
logit_mask[mask] = True
seq = torch.where(logit_mask, id_gnn_seq.view(-1,embed_size), seq.view(-1,embed_size))

lhs_uniq_embed = lhs_embedding[:batch_size]

seq = seq.clone()
# seq = seq.clone()
seq = seq.view(batch_size,self.rank_topk,-1)

lhs_uniq_embed = lhs_uniq_embed.view(-1,1,embed_size)
Expand All @@ -222,16 +224,16 @@ def rerank(self, gnn_logits, shallow_rhs_embed, rhs_idgnn_embed, rhs_idgnn_index
seq = block(seq, seq) # [# nodes, topk, embed_size]

#! just get the logit directly from transformer
seq = seq.view(-1,embed_size*2)
seq = seq.reshape(-1,self.tr_embed_size)
tr_logits = self.tr_lin(seq) # [batch_size * rank_topk, 1]
tr_logits = tr_logits.view(batch_size,self.rank_topk)

# seq = seq.view(batch_size * self.rank_topk, embed_size)
# lhs_uniq_embed = lhs_uniq_embed.reshape(batch_size * self.rank_topk, embed_size)

# tr_logits = (lhs_uniq_embed.view(-1, embed_size) * seq.view(-1, embed_size)).sum(
# dim=-1).flatten()

tr_logits = tr_logits.view(batch_size,self.rank_topk)

return tr_logits, out_indices


Expand Down

0 comments on commit b86641f

Please sign in to comment.