From 829fee50a9f132c327e7f38f890dcaaacfe29e92 Mon Sep 17 00:00:00 2001 From: Shenyang Huang Date: Thu, 8 Aug 2024 22:06:10 +0000 Subject: [PATCH 01/22] adding initial implementation of rhs transformer to repo --- .gitignore | 1 + README.md | 3 + .../relbench_link_prediction_benchmark.py | 22 ++- hybridgnn/nn/models/__init__.py | 6 +- hybridgnn/nn/models/hybrid_rhstransformer.py | 153 ++++++++++++++++++ hybridgnn/nn/models/transformer.py | 113 +++++++++++++ 6 files changed, 292 insertions(+), 6 deletions(-) create mode 100644 hybridgnn/nn/models/hybrid_rhstransformer.py create mode 100644 hybridgnn/nn/models/transformer.py diff --git a/.gitignore b/.gitignore index ced7769..bbc0e94 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,4 @@ coverage.xml venv/* *.out data/** +*.txt diff --git a/README.md b/README.md index e59c177..f0354b1 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ Run [`benchmark/relbench_link_prediction_benchmark.py`](https://github.com/kumo-ai/hybridgnn/blob/master/benchmark/relbench_link_prediction_benchmark.py) ```sh +python relbench_link_prediction_benchmark.py --dataset rel-hm --task user-item-purcahse --model rhstransformer python relbench_link_prediction_benchmark.py --dataset rel-trial --task site-sponsor-run --model hybridgnn ``` @@ -31,4 +32,6 @@ pip install -e . # to run examples and benchmarks pip install -e '.[full]' + +pip install -U sentence-transformers ``` diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index 4a857fa..3a0f6ee 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -30,7 +30,7 @@ from torch_geometric.utils.cross_entropy import sparse_cross_entropy from tqdm import tqdm -from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN +from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN, Hybrid_RHSTransformer from hybridgnn.utils import GloveTextEmbedding TRAIN_CONFIG_KEYS = ["batch_size", "gamma_rate", "base_lr"] @@ -43,7 +43,7 @@ "--model", type=str, default="hybridgnn", - choices=["hybridgnn", "idgnn", "shallowrhsgnn"], + choices=["hybridgnn", "idgnn", "shallowrhsgnn", "rhstransformer"], ) parser.add_argument("--epochs", type=int, default=20) parser.add_argument("--num_trials", type=int, default=10, @@ -102,7 +102,7 @@ int(args.num_neighbors // 2**i) for i in range(args.num_layers) ] -model_cls: Type[Union[IDGNN, HybridGNN, ShallowRHSGNN]] +model_cls: Type[Union[IDGNN, HybridGNN, ShallowRHSGNN, Hybrid_RHSTransformer]] if args.model == "idgnn": model_search_space = { @@ -127,6 +127,18 @@ "gamma_rate": [0.9, 0.95, 1.], } model_cls = (HybridGNN if args.model == "hybridgnn" else ShallowRHSGNN) +elif args.model in ["rhstransformer"]: + model_search_space = { + "channels": [64, 128, 256], + "embedding_dim": [64, 128, 256], + "norm": ["layer_norm", "batch_norm"] + } + train_search_space = { + "batch_size": [256, 512, 1024], + "base_lr": [0.001, 0.01], + "gamma_rate": [0.9, 0.95, 1.], + } + model_cls = Hybrid_RHSTransformer def train( @@ -164,7 +176,7 @@ def train( loss = F.binary_cross_entropy_with_logits(out, target) numel = out.numel() - elif args.model in ["hybridgnn", "shallowrhsgnn"]: + elif args.model in ["hybridgnn", "shallowrhsgnn", "rhstransformer"]: logits = model(batch, task.src_entity_table, task.dst_entity_table) edge_label_index = torch.stack([src_batch, dst_index], dim=0) loss = sparse_cross_entropy(logits, edge_label_index) @@ -248,7 +260,7 @@ def train_and_eval_with_cfg( 
persistent_workers=args.num_workers > 0, ) - if args.model in ["hybridgnn", "shallowrhsgnn"]: + if args.model in ["hybridgnn", "shallowrhsgnn", "rhstransformer"]: model_cfg["num_nodes"] = num_dst_nodes_dict["train"] elif args.model == "idgnn": model_cfg["out_channels"] = 1 diff --git a/hybridgnn/nn/models/__init__.py b/hybridgnn/nn/models/__init__.py index c8e9ef2..f20d482 100644 --- a/hybridgnn/nn/models/__init__.py +++ b/hybridgnn/nn/models/__init__.py @@ -2,5 +2,9 @@ from .idgnn import IDGNN from .hybridgnn import HybridGNN from .shallowrhsgnn import ShallowRHSGNN +from .hybrid_rhstransformer import Hybrid_RHSTransformer -__all__ = classes = ['HeteroGraphSAGE', 'IDGNN', 'HybridGNN', 'ShallowRHSGNN'] +__all__ = classes = [ + 'HeteroGraphSAGE', 'IDGNN', 'HybridGNN', 'ShallowRHSGNN', + 'Hybrid_RHSTransformer' +] diff --git a/hybridgnn/nn/models/hybrid_rhstransformer.py b/hybridgnn/nn/models/hybrid_rhstransformer.py new file mode 100644 index 0000000..56ae32e --- /dev/null +++ b/hybridgnn/nn/models/hybrid_rhstransformer.py @@ -0,0 +1,153 @@ +from typing import Any, Dict + +import torch +from torch import Tensor +from torch_frame.data.stats import StatType +from torch_geometric.data import HeteroData +from torch_geometric.nn import MLP +from torch_geometric.typing import NodeType + +from hybridgnn.nn.encoder import ( + DEFAULT_STYPE_ENCODER_DICT, + HeteroEncoder, + HeteroTemporalEncoder, +) +from hybridgnn.nn.models import HeteroGraphSAGE +from hybridgnn.nn.models.transformer import RHSTransformer + + +class Hybrid_RHSTransformer(torch.nn.Module): + r"""Implementation of RHSTransformer model.""" + def __init__( + self, + data: HeteroData, + col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]], + num_nodes: int, + num_layers: int, + channels: int, + embedding_dim: int, + aggr: str = 'sum', + norm: str = 'layer_norm', + pe: str = "abs", + ) -> None: + super().__init__() + + self.encoder = HeteroEncoder( + channels=channels, + node_to_col_names_dict={ + node_type: data[node_type].tf.col_names_dict + for node_type in data.node_types + }, + node_to_col_stats=col_stats_dict, + stype_encoder_cls_kwargs=DEFAULT_STYPE_ENCODER_DICT, + ) + self.temporal_encoder = HeteroTemporalEncoder( + node_types=[ + node_type for node_type in data.node_types + if "time" in data[node_type] + ], + channels=channels, + ) + self.gnn = HeteroGraphSAGE( + node_types=data.node_types, + edge_types=data.edge_types, + channels=channels, + aggr=aggr, + num_layers=num_layers, + ) + self.head = MLP( + channels, + out_channels=1, + norm=norm, + num_layers=1, + ) + self.lhs_projector = torch.nn.Linear(channels, embedding_dim) + + self.id_awareness_emb = torch.nn.Embedding(1, channels) + self.rhs_embedding = torch.nn.Embedding(num_nodes, embedding_dim) + self.lin_offset_idgnn = torch.nn.Linear(embedding_dim, 1) + self.lin_offset_embgnn = torch.nn.Linear(embedding_dim, 1) + self.rhs_transformer = RHSTransformer(in_channels=channels, + out_channels=channels, + hidden_channels=channels, + heads=1, dropout=0.2, + position_encoding=pe) + + self.channels = channels + + self.reset_parameters() + + def reset_parameters(self) -> None: + self.encoder.reset_parameters() + self.temporal_encoder.reset_parameters() + self.gnn.reset_parameters() + self.head.reset_parameters() + self.id_awareness_emb.reset_parameters() + self.rhs_embedding.reset_parameters() + self.lin_offset_embgnn.reset_parameters() + self.lin_offset_idgnn.reset_parameters() + self.lhs_projector.reset_parameters() + self.rhs_transformer.reset_parameters() + + def forward( + 
self, + batch: HeteroData, + entity_table: NodeType, + dst_table: NodeType, + ) -> Tensor: + seed_time = batch[entity_table].seed_time + x_dict = self.encoder(batch.tf_dict) + + # Add ID-awareness to the root node + x_dict[entity_table][:seed_time.size(0 + )] += self.id_awareness_emb.weight + rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict, + batch.batch_dict) + + for node_type, rel_time in rel_time_dict.items(): + x_dict[node_type] = x_dict[node_type] + rel_time + + x_dict = self.gnn( + x_dict, + batch.edge_index_dict, + ) + + batch_size = seed_time.size(0) + lhs_embedding = x_dict[entity_table][: + batch_size] # batch_size, channel + lhs_embedding_projected = self.lhs_projector(lhs_embedding) + rhs_gnn_embedding = x_dict[dst_table] # num_sampled_rhs, channel + rhs_idgnn_index = batch.n_id_dict[dst_table] # num_sampled_rhs + lhs_idgnn_batch = batch.batch_dict[dst_table] # batch_size + + #! adding transformer here + rhs_gnn_embedding = self.rhs_transformer(rhs_gnn_embedding, + lhs_idgnn_batch) + + rhs_embedding = self.rhs_embedding # num_rhs_nodes, channel + embgnn_logits = lhs_embedding_projected @ rhs_embedding.weight.t( + ) # batch_size, num_rhs_nodes + + # Model the importance of embedding-GNN prediction for each lhs node + embgnn_offset_logits = self.lin_offset_embgnn( + lhs_embedding_projected).flatten() + embgnn_logits += embgnn_offset_logits.view(-1, 1) + + # Calculate idgnn logits + idgnn_logits = self.head( + rhs_gnn_embedding).flatten() # num_sampled_rhs + # Because we are only doing 2 hop, we are not really sampling info from + # lhs therefore, we need to incorporate this information using + # lhs_embedding[lhs_idgnn_batch] * rhs_gnn_embedding + idgnn_logits += ( + lhs_embedding[lhs_idgnn_batch] * # num_sampled_rhs, channel + rhs_gnn_embedding).sum( + dim=-1).flatten() # num_sampled_rhs, channel + + # Model the importance of ID-GNN prediction for each lhs node + idgnn_offset_logits = self.lin_offset_idgnn( + lhs_embedding_projected).flatten() + idgnn_logits = idgnn_logits + idgnn_offset_logits[lhs_idgnn_batch] + + embgnn_logits[lhs_idgnn_batch, rhs_idgnn_index] = idgnn_logits + return embgnn_logits diff --git a/hybridgnn/nn/models/transformer.py b/hybridgnn/nn/models/transformer.py new file mode 100644 index 0000000..dcab570 --- /dev/null +++ b/hybridgnn/nn/models/transformer.py @@ -0,0 +1,113 @@ +import torch +from torch import Tensor, nn +from torch_geometric.typing import EdgeType, NodeType +from torch.nested import nested_tensor + +from torch_geometric.nn.aggr.utils import MultiheadAttentionBlock +from torch_geometric.utils import to_dense_batch, to_nested_tensor, from_nested_tensor +from torch_geometric.utils import cumsum, scatter +from torch_geometric.nn.encoding import PositionalEncoding + + +class RHSTransformer(torch.nn.Module): + r"""A module to attend to rhs embeddings with a transformer. + Args: + in_channels (int): The number of input channels of the RHS embedding. + out_channels (int): The number of output channels. + hidden_channels (int): The hidden channel dimension of the transformer. + heads (int): The number of attention heads for the transformer. + num_transformer_blocks (int): The number of transformer blocks. 
+ dropout (float): dropout rate for the transformer + """ + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: int = 64, + heads: int = 1, + num_transformer_blocks: int = 1, + dropout: float = 0.0, + position_encoding: str = "abs", + ) -> None: + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.lin = torch.nn.Linear(in_channels, hidden_channels) + self.fc = torch.nn.Linear(hidden_channels, out_channels) + self.pe_type = position_encoding + self.pe = None + if (position_encoding == "abs"): + self.pe = PositionalEncoding(hidden_channels) + elif (position_encoding == "rope"): + # rotary pe for queries + self.q_pe = RotaryPositionalEmbeddings(hidden_channels) + # rotary pe for keys + self.k_pe = RotaryPositionalEmbeddings(hidden_channels) + + self.blocks = torch.nn.ModuleList([ + MultiheadAttentionBlock( + channels=hidden_channels, + heads=heads, + layer_norm=True, + dropout=dropout, + ) for _ in range(num_transformer_blocks) + ]) + + def reset_parameters(self): + for block in self.blocks: + block.reset_parameters() + self.lin.reset_parameters() + self.fc.reset_parameters() + + def forward(self, rhs_embed: Tensor, index: Tensor, + rhs_time: Tensor = None) -> Tensor: + r"""Returns the attended to rhs embeddings + """ + rhs_embed = self.lin(rhs_embed) + + if (self.pe_type == "abs"): + if (rhs_time is None): + rhs_embed = rhs_embed + self.pe( + torch.arange(rhs_embed.size(0), device=rhs_embed.device)) + else: + rhs_embed = rhs_embed + self.pe(rhs_time) + + x, mask = to_dense_batch(rhs_embed, index) + for block in self.blocks: + # apply the pe for both query and key + if (self.pe_type == "rope"): + x_q = self.q_pe(x, pos=rhs_time) + x_k = self.k_pe(x, pos=rhs_time) + else: + x_q = x + x_k = x + x = block(x_q, x_k) + x = x[mask] + x = x.view(-1, self.hidden_channels) + return self.fc(x) + + +class RotaryPositionalEmbeddings(torch.nn.Module): + def __init__(self, channels, base=10000): + super().__init__() + self.channels = channels + self.base = base + self.inv_freq = 1. 
/ (base**(torch.arange(0, channels, 2).float() / + channels)) + + def forward(self, x, pos=None): + seq_len = x.shape[1] + if (pos is None): + pos = torch.arange(seq_len, device=x.device).type_as(self.inv_freq) + freqs = torch.einsum('i,j->ij', pos, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + + cos = emb.cos().to(x.device) + sin = emb.sin().to(x.device) + + x1, x2 = x[..., ::2], x[..., 1::2] + rotated = torch.stack([-x2, x1], dim=-1).reshape(x.shape).to(x.device) + + return x * cos + rotated * sin From 12a83bd062432718765b3113234edcf9cd1feeff Mon Sep 17 00:00:00 2001 From: Shenyang Huang Date: Fri, 9 Aug 2024 18:54:43 +0000 Subject: [PATCH 02/22] adding rhs transformer to the benchmark script --- .../relbench_link_prediction_benchmark.py | 14 +++++++----- hybridgnn/nn/models/hybrid_rhstransformer.py | 22 +++++++++++++++---- hybridgnn/nn/models/transformer.py | 19 +++++----------- 3 files changed, 32 insertions(+), 23 deletions(-) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index 3a0f6ee..470bfb7 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -129,14 +129,16 @@ model_cls = (HybridGNN if args.model == "hybridgnn" else ShallowRHSGNN) elif args.model in ["rhstransformer"]: model_search_space = { - "channels": [64, 128, 256], - "embedding_dim": [64, 128, 256], - "norm": ["layer_norm", "batch_norm"] + "channels": [64, 128], + "embedding_dim": [64, 128], + "norm": ["layer_norm"], + "dropout": [0.1, 0.2], + "pe": ["abs", "none"], } train_search_space = { "batch_size": [256, 512, 1024], - "base_lr": [0.001, 0.01], - "gamma_rate": [0.9, 0.95, 1.], + "base_lr": [0.001, 0.01, 0.0001], + "gamma_rate": [0.9, 1.0], } model_cls = Hybrid_RHSTransformer @@ -217,7 +219,7 @@ def test(model: torch.nn.Module, loader: NeighborLoader, stage: str) -> float: device=out.device) scores[batch[task.dst_entity_table].batch, batch[task.dst_entity_table].n_id] = torch.sigmoid(out) - elif args.model in ["hybridgnn", "shallowrhsgnn"]: + elif args.model in ["hybridgnn", "shallowrhsgnn", "rhstransformer"]: # Get ground-truth out = model(batch, task.src_entity_table, task.dst_entity_table).detach() diff --git a/hybridgnn/nn/models/hybrid_rhstransformer.py b/hybridgnn/nn/models/hybrid_rhstransformer.py index 56ae32e..f7c3d2a 100644 --- a/hybridgnn/nn/models/hybrid_rhstransformer.py +++ b/hybridgnn/nn/models/hybrid_rhstransformer.py @@ -17,7 +17,19 @@ class Hybrid_RHSTransformer(torch.nn.Module): - r"""Implementation of RHSTransformer model.""" + r"""Implementation of RHSTransformer model. 
+ Args: + data (HeteroData): dataset + col_stats_dict (Dict[str, Dict[str, Dict[StatType, Any]]]): column stats + num_nodes (int): number of nodes, + num_layers (int): number of mp layers, + channels (int): input dimension, + embedding_dim (int): embedding dimension size, + aggr (str): aggregation type, + norm (norm): normalization type, + dropout (float): dropout rate for the transformer float, + heads (int): number of attention heads, + pe (str): type of positional encoding for the transformer,""" def __init__( self, data: HeteroData, @@ -28,6 +40,8 @@ def __init__( embedding_dim: int, aggr: str = 'sum', norm: str = 'layer_norm', + dropout: float = 0.2, + heads: int = 1, pe: str = "abs", ) -> None: super().__init__() @@ -62,15 +76,15 @@ def __init__( num_layers=1, ) self.lhs_projector = torch.nn.Linear(channels, embedding_dim) - self.id_awareness_emb = torch.nn.Embedding(1, channels) self.rhs_embedding = torch.nn.Embedding(num_nodes, embedding_dim) self.lin_offset_idgnn = torch.nn.Linear(embedding_dim, 1) self.lin_offset_embgnn = torch.nn.Linear(embedding_dim, 1) + self.rhs_transformer = RHSTransformer(in_channels=channels, out_channels=channels, hidden_channels=channels, - heads=1, dropout=0.2, + heads=heads, dropout=dropout, position_encoding=pe) self.channels = channels @@ -120,7 +134,7 @@ def forward( rhs_idgnn_index = batch.n_id_dict[dst_table] # num_sampled_rhs lhs_idgnn_batch = batch.batch_dict[dst_table] # batch_size - #! adding transformer here + # adding rhs transformer rhs_gnn_embedding = self.rhs_transformer(rhs_gnn_embedding, lhs_idgnn_batch) diff --git a/hybridgnn/nn/models/transformer.py b/hybridgnn/nn/models/transformer.py index dcab570..f8290db 100644 --- a/hybridgnn/nn/models/transformer.py +++ b/hybridgnn/nn/models/transformer.py @@ -18,6 +18,7 @@ class RHSTransformer(torch.nn.Module): heads (int): The number of attention heads for the transformer. num_transformer_blocks (int): The number of transformer blocks. 
dropout (float): dropout rate for the transformer + position_encoding (str): type of positional encoding, """ def __init__( self, @@ -40,11 +41,10 @@ def __init__( self.pe = None if (position_encoding == "abs"): self.pe = PositionalEncoding(hidden_channels) - elif (position_encoding == "rope"): - # rotary pe for queries - self.q_pe = RotaryPositionalEmbeddings(hidden_channels) - # rotary pe for keys - self.k_pe = RotaryPositionalEmbeddings(hidden_channels) + elif (position_encoding == "none"): + self.pe = None + else: + raise NotImplementedError self.blocks = torch.nn.ModuleList([ MultiheadAttentionBlock( @@ -76,14 +76,7 @@ def forward(self, rhs_embed: Tensor, index: Tensor, x, mask = to_dense_batch(rhs_embed, index) for block in self.blocks: - # apply the pe for both query and key - if (self.pe_type == "rope"): - x_q = self.q_pe(x, pos=rhs_time) - x_k = self.k_pe(x, pos=rhs_time) - else: - x_q = x - x_k = x - x = block(x_q, x_k) + x = block(x, x) x = x[mask] x = x.view(-1, self.hidden_channels) return self.fc(x) From a3a9779817326292497edbd2dc740500df326f95 Mon Sep 17 00:00:00 2001 From: shenyang Date: Mon, 12 Aug 2024 18:55:04 +0000 Subject: [PATCH 03/22] rhs transformer upload --- README.md | 6 +- .../relbench_link_prediction_benchmark.py | 22 +++++-- hybridgnn/nn/models/hybrid_rhstransformer.py | 65 ++++++++++++++++++- hybridgnn/nn/models/transformer.py | 15 ++--- 4 files changed, 89 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index f0354b1..622e1a1 100644 --- a/README.md +++ b/README.md @@ -13,15 +13,15 @@ Run [`benchmark/relbench_link_prediction_benchmark.py`](https://github.com/kumo- ```sh python relbench_link_prediction_benchmark.py --dataset rel-hm --task user-item-purcahse --model rhstransformer -python relbench_link_prediction_benchmark.py --dataset rel-trial --task site-sponsor-run --model hybridgnn +python relbench_link_prediction_benchmark.py --dataset rel-trial --task site-sponsor-run --model hybridgnn --num_layers 4 ``` Run [`examples/relbench_example.py`](https://github.com/kumo-ai/hybridgnn/blob/master/examples/relbench_example.py) ```sh -python relbench_example.py --dataset rel-trial --task site-sponsor-run --model hybridgnn -python relbench_example.py --dataset rel-trial --task condition-sponsor-run --model hybridgnn +python relbench_example.py --dataset rel-trial --task site-sponsor-run --model hybridgnn --num_layers 4 +python relbench_example.py --dataset rel-trial --task condition-sponsor-run --model hybridgnn --num_layers 4 ``` diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index 470bfb7..b326f3b 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -117,12 +117,12 @@ model_cls = IDGNN elif args.model in ["hybridgnn", "shallowrhsgnn"]: model_search_space = { - "channels": [64, 128, 256], - "embedding_dim": [64, 128, 256], + "channels": [64, 128], + "embedding_dim": [64, 128], "norm": ["layer_norm", "batch_norm"] } train_search_space = { - "batch_size": [256, 512, 1024], + "batch_size": [256], "base_lr": [0.001, 0.01], "gamma_rate": [0.9, 0.95, 1.], } @@ -136,7 +136,7 @@ "pe": ["abs", "none"], } train_search_space = { - "batch_size": [256, 512, 1024], + "batch_size": [128], "base_lr": [0.001, 0.01, 0.0001], "gamma_rate": [0.9, 1.0], } @@ -178,11 +178,16 @@ def train( loss = F.binary_cross_entropy_with_logits(out, target) numel = out.numel() - elif args.model in ["hybridgnn", "shallowrhsgnn", "rhstransformer"]: + elif 
args.model in ["hybridgnn", "shallowrhsgnn"]: logits = model(batch, task.src_entity_table, task.dst_entity_table) edge_label_index = torch.stack([src_batch, dst_index], dim=0) loss = sparse_cross_entropy(logits, edge_label_index) numel = len(batch[task.dst_entity_table].batch) + elif args.model in ["rhstransformer"]: + logits = model(batch, task.src_entity_table, task.dst_entity_table, task.dst_entity_col) + edge_label_index = torch.stack([src_batch, dst_index], dim=0) + loss = sparse_cross_entropy(logits, edge_label_index) + numel = len(batch[task.dst_entity_table].batch) loss.backward() optimizer.step() @@ -219,11 +224,16 @@ def test(model: torch.nn.Module, loader: NeighborLoader, stage: str) -> float: device=out.device) scores[batch[task.dst_entity_table].batch, batch[task.dst_entity_table].n_id] = torch.sigmoid(out) - elif args.model in ["hybridgnn", "shallowrhsgnn", "rhstransformer"]: + elif args.model in ["hybridgnn", "shallowrhsgnn"]: # Get ground-truth out = model(batch, task.src_entity_table, task.dst_entity_table).detach() scores = torch.sigmoid(out) + elif args.model in ["rhstransformer"]: + out = model(batch, task.src_entity_table, + task.dst_entity_table, + task.dst_entity_col).detach() + scores = torch.sigmoid(out) else: raise ValueError(f"Unsupported model type: {args.model}.") diff --git a/hybridgnn/nn/models/hybrid_rhstransformer.py b/hybridgnn/nn/models/hybrid_rhstransformer.py index f7c3d2a..a16de0a 100644 --- a/hybridgnn/nn/models/hybrid_rhstransformer.py +++ b/hybridgnn/nn/models/hybrid_rhstransformer.py @@ -14,6 +14,7 @@ ) from hybridgnn.nn.models import HeteroGraphSAGE from hybridgnn.nn.models.transformer import RHSTransformer +from torch_scatter import scatter_max class Hybrid_RHSTransformer(torch.nn.Module): @@ -108,7 +109,12 @@ def forward( batch: HeteroData, entity_table: NodeType, dst_table: NodeType, + dst_entity_col: NodeType, ) -> Tensor: + # print ("time dict has the following keys") + # print (batch.time_dict.keys()) + # dict_keys(['drop_withdrawals', 'outcomes', 'outcome_analyses', 'eligibilities', 'sponsors_studies', 'facilities_studies', 'interventions_studies', 'studies', 'designs', 'reported_event_totals', 'conditions_studies']) + seed_time = batch[entity_table].seed_time x_dict = self.encoder(batch.tf_dict) @@ -134,9 +140,12 @@ def forward( rhs_idgnn_index = batch.n_id_dict[dst_table] # num_sampled_rhs lhs_idgnn_batch = batch.batch_dict[dst_table] # batch_size + #! 
need custom code to work for specific datasets + # rhs_time = self.get_rhs_time_dict(batch.time_dict, batch.edge_index_dict, batch[entity_table].seed_time, batch, dst_entity_col, dst_table) + # adding rhs transformer rhs_gnn_embedding = self.rhs_transformer(rhs_gnn_embedding, - lhs_idgnn_batch) + lhs_idgnn_batch, batch_size=batch_size) rhs_embedding = self.rhs_embedding # num_rhs_nodes, channel embgnn_logits = lhs_embedding_projected @ rhs_embedding.weight.t( @@ -165,3 +174,57 @@ def forward( embgnn_logits[lhs_idgnn_batch, rhs_idgnn_index] = idgnn_logits return embgnn_logits + + def get_rhs_time_dict( + self, + time_dict, + edge_index_dict, + seed_time, + batch_dict, + dst_entity_col, + dst_entity_table, + ): + # edge_index_dict keys + """ + dict_keys([('drop_withdrawals', 'f2p_nct_id', 'studies'), + ('studies', 'rev_f2p_nct_id', 'drop_withdrawals'), + ('outcomes', 'f2p_nct_id', 'studies'), + ('studies', 'rev_f2p_nct_id', 'outcomes'), + ('outcome_analyses', 'f2p_nct_id', 'studies'), + ('studies', 'rev_f2p_nct_id', 'outcome_analyses'), + ('outcome_analyses', 'f2p_outcome_id', 'outcomes'), + ('outcomes', 'rev_f2p_outcome_id', 'outcome_analyses'), + ('eligibilities', 'f2p_nct_id', 'studies'), + ('studies', 'rev_f2p_nct_id', 'eligibilities'), + ('sponsors_studies', 'f2p_nct_id', 'studies'), + ('studies', 'rev_f2p_nct_id', 'sponsors_studies'), + ('sponsors_studies', 'f2p_sponsor_id', 'sponsors'), + ('sponsors', 'rev_f2p_sponsor_id', 'sponsors_studies'), + ('facilities_studies', 'f2p_nct_id', 'studies'), + ('studies', 'rev_f2p_nct_id', 'facilities_studies'), + ('facilities_studies', 'f2p_facility_id', 'facilities'), + ('facilities', 'rev_f2p_facility_id', 'facilities_studies'), + ('interventions_studies', 'f2p_nct_id', 'studies'), + ('studies', 'rev_f2p_nct_id', 'interventions_studies'), + ('interventions_studies', 'f2p_intervention_id', 'interventions'), + ('interventions', 'rev_f2p_intervention_id', 'interventions_studies'), + ('designs', 'f2p_nct_id', 'studies'), + ('studies', 'rev_f2p_nct_id', 'designs'), + ('reported_event_totals', 'f2p_nct_id', 'studies'), + ('studies', 'rev_f2p_nct_id', 'reported_event_totals'), + ('conditions_studies', 'f2p_nct_id', 'studies'), + ('studies', 'rev_f2p_nct_id', 'conditions_studies'), + ('conditions_studies', 'f2p_condition_id', 'conditions'), + ('conditions', 'rev_f2p_condition_id', 'conditions_studies')]) + """ + #* what to put when transaction table is merged + edge_index = edge_index_dict['sponsors','f2p_sponsor_id', + 'sponsors_studies'] + rhs_time, _ = scatter_max( + time_dict['sponsors'][edge_index[0]], + edge_index[1]) + SECONDS_IN_A_DAY = 60 * 60 * 24 + NANOSECONDS_IN_A_DAY = 60 * 60 * 24 * 1000000000 + rhs_rel_time = seed_time[batch_dict[dst_entity_col]] - rhs_time + rhs_rel_time = rhs_rel_time / NANOSECONDS_IN_A_DAY + return rhs_rel_time diff --git a/hybridgnn/nn/models/transformer.py b/hybridgnn/nn/models/transformer.py index f8290db..b189c2f 100644 --- a/hybridgnn/nn/models/transformer.py +++ b/hybridgnn/nn/models/transformer.py @@ -1,4 +1,5 @@ import torch +import math from torch import Tensor, nn from torch_geometric.typing import EdgeType, NodeType from torch.nested import nested_tensor @@ -61,20 +62,16 @@ def reset_parameters(self): self.lin.reset_parameters() self.fc.reset_parameters() - def forward(self, rhs_embed: Tensor, index: Tensor, - rhs_time: Tensor = None) -> Tensor: + + def forward(self, rhs_embed: Tensor, index: Tensor, batch_size=512) -> Tensor: r"""Returns the attended to rhs embeddings """ rhs_embed = self.lin(rhs_embed) if 
(self.pe_type == "abs"): - if (rhs_time is None): - rhs_embed = rhs_embed + self.pe( - torch.arange(rhs_embed.size(0), device=rhs_embed.device)) - else: - rhs_embed = rhs_embed + self.pe(rhs_time) - - x, mask = to_dense_batch(rhs_embed, index) + rhs_embed = rhs_embed + self.pe( + torch.arange(rhs_embed.size(0), device=rhs_embed.device)) + x, mask = to_dense_batch(rhs_embed, index, batch_size=batch_size) for block in self.blocks: x = block(x, x) x = x[mask] From 988781ba6c6d6e88a89dae59f6f38e8c78767008 Mon Sep 17 00:00:00 2001 From: shenyang Date: Mon, 12 Aug 2024 19:20:11 +0000 Subject: [PATCH 04/22] updating tr --- hybridgnn/nn/models/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hybridgnn/nn/models/transformer.py b/hybridgnn/nn/models/transformer.py index b189c2f..5f88da6 100644 --- a/hybridgnn/nn/models/transformer.py +++ b/hybridgnn/nn/models/transformer.py @@ -71,7 +71,7 @@ def forward(self, rhs_embed: Tensor, index: Tensor, batch_size=512) -> Tensor: if (self.pe_type == "abs"): rhs_embed = rhs_embed + self.pe( torch.arange(rhs_embed.size(0), device=rhs_embed.device)) - x, mask = to_dense_batch(rhs_embed, index, batch_size=batch_size) + x, mask = to_dense_batch(rhs_embed, index, max_num_nodes=batch_size) for block in self.blocks: x = block(x, x) x = x[mask] From e5faa22afa74bce3c0e6697ec060a73e44ed8b5a Mon Sep 17 00:00:00 2001 From: shenyang Date: Mon, 12 Aug 2024 21:08:33 +0000 Subject: [PATCH 05/22] running code --- benchmark/relbench_link_prediction_benchmark.py | 7 ++++--- hybridgnn/nn/models/transformer.py | 6 +++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index b326f3b..984079d 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -129,14 +129,15 @@ model_cls = (HybridGNN if args.model == "hybridgnn" else ShallowRHSGNN) elif args.model in ["rhstransformer"]: model_search_space = { - "channels": [64, 128], - "embedding_dim": [64, 128], + "channels": [64], + "embedding_dim": [64], "norm": ["layer_norm"], "dropout": [0.1, 0.2], "pe": ["abs", "none"], + "num_neighbors": [args.num_neighbors], } train_search_space = { - "batch_size": [128], + "batch_size": [64], "base_lr": [0.001, 0.01, 0.0001], "gamma_rate": [0.9, 1.0], } diff --git a/hybridgnn/nn/models/transformer.py b/hybridgnn/nn/models/transformer.py index 5f88da6..40a3e8c 100644 --- a/hybridgnn/nn/models/transformer.py +++ b/hybridgnn/nn/models/transformer.py @@ -71,7 +71,11 @@ def forward(self, rhs_embed: Tensor, index: Tensor, batch_size=512) -> Tensor: if (self.pe_type == "abs"): rhs_embed = rhs_embed + self.pe( torch.arange(rhs_embed.size(0), device=rhs_embed.device)) - x, mask = to_dense_batch(rhs_embed, index, max_num_nodes=batch_size) + + sorted_index, _ = torch.sort(index) + index = sorted_index + + x, mask = to_dense_batch(rhs_embed, index, batch_size=batch_size) for block in self.blocks: x = block(x, x) x = x[mask] From 5e904eae13c5bcbc9411106f57835671aacc3cb0 Mon Sep 17 00:00:00 2001 From: shenyang Date: Mon, 12 Aug 2024 21:09:22 +0000 Subject: [PATCH 06/22] running code --- benchmark/relbench_link_prediction_benchmark.py | 1 - 1 file changed, 1 deletion(-) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index 984079d..ddb626d 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py 
@@ -134,7 +134,6 @@ "norm": ["layer_norm"], "dropout": [0.1, 0.2], "pe": ["abs", "none"], - "num_neighbors": [args.num_neighbors], } train_search_space = { "batch_size": [64], From d6ff1698e29d26b298723a02640c989d5777b65a Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Tue, 13 Aug 2024 20:47:16 +0000 Subject: [PATCH 07/22] adding transformer changes --- README.md | 1 + hybridgnn/nn/models/transformer.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 622e1a1..d02d749 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ Run [`benchmark/relbench_link_prediction_benchmark.py`](https://github.com/kumo-ai/hybridgnn/blob/master/benchmark/relbench_link_prediction_benchmark.py) ```sh +python relbench_link_prediction_benchmark.py --dataset rel-stack --task user-post-comment --model rhstransformer --num_trials 10 python relbench_link_prediction_benchmark.py --dataset rel-hm --task user-item-purcahse --model rhstransformer python relbench_link_prediction_benchmark.py --dataset rel-trial --task site-sponsor-run --model hybridgnn --num_layers 4 ``` diff --git a/hybridgnn/nn/models/transformer.py b/hybridgnn/nn/models/transformer.py index 40a3e8c..502ee4e 100644 --- a/hybridgnn/nn/models/transformer.py +++ b/hybridgnn/nn/models/transformer.py @@ -72,8 +72,9 @@ def forward(self, rhs_embed: Tensor, index: Tensor, batch_size=512) -> Tensor: rhs_embed = rhs_embed + self.pe( torch.arange(rhs_embed.size(0), device=rhs_embed.device)) - sorted_index, _ = torch.sort(index) - index = sorted_index + # #! if we sort the index, we need to sort the rhs_embed + # sorted_index, _ = torch.sort(index) + # assert torch.equal(index, sorted_index) x, mask = to_dense_batch(rhs_embed, index, batch_size=batch_size) for block in self.blocks: From 4884b837b081c0f0fd363a93c6528851949511fc Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Tue, 13 Aug 2024 22:19:00 +0000 Subject: [PATCH 08/22] permute the index, the rhs and then reverse it --- hybridgnn/nn/models/transformer.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/hybridgnn/nn/models/transformer.py b/hybridgnn/nn/models/transformer.py index 502ee4e..d227d8b 100644 --- a/hybridgnn/nn/models/transformer.py +++ b/hybridgnn/nn/models/transformer.py @@ -63,7 +63,7 @@ def reset_parameters(self): self.fc.reset_parameters() - def forward(self, rhs_embed: Tensor, index: Tensor, batch_size=512) -> Tensor: + def forward(self, rhs_embed: Tensor, index: Tensor, batch_size) -> Tensor: r"""Returns the attended to rhs embeddings """ rhs_embed = self.lin(rhs_embed) @@ -73,7 +73,10 @@ def forward(self, rhs_embed: Tensor, index: Tensor, batch_size=512) -> Tensor: torch.arange(rhs_embed.size(0), device=rhs_embed.device)) # #! 
if we sort the index, we need to sort the rhs_embed - # sorted_index, _ = torch.sort(index) + sorted_index, sorted_idx = torch.sort(index, stable=True) + index = index[sorted_idx] + rhs_embed = rhs_embed[sorted_idx] + reverse = self.inverse_permutation(sorted_idx) # assert torch.equal(index, sorted_index) x, mask = to_dense_batch(rhs_embed, index, batch_size=batch_size) @@ -81,8 +84,16 @@ def forward(self, rhs_embed: Tensor, index: Tensor, batch_size=512) -> Tensor: x = block(x, x) x = x[mask] x = x.view(-1, self.hidden_channels) + x = x[reverse] + # x = x.gather(1, sorted_idx.argsort(1)) + return self.fc(x) + def inverse_permutation(self,perm): + inv = torch.empty_like(perm) + inv[perm] = torch.arange(perm.size(0), device=perm.device) + return inv + class RotaryPositionalEmbeddings(torch.nn.Module): def __init__(self, channels, base=10000): From 08041cb73905f888ac7a2c58beabc272cc31b6c2 Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Tue, 13 Aug 2024 22:21:15 +0000 Subject: [PATCH 09/22] removing none to replace with None --- .../relbench_link_prediction_benchmark.py | 2 +- hybridgnn/nn/models/hybrid_rhstransformer.py | 33 ------------------- hybridgnn/nn/models/transformer.py | 2 +- 3 files changed, 2 insertions(+), 35 deletions(-) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index ddb626d..2586e86 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -133,7 +133,7 @@ "embedding_dim": [64], "norm": ["layer_norm"], "dropout": [0.1, 0.2], - "pe": ["abs", "none"], + "pe": ["abs", None], } train_search_space = { "batch_size": [64], diff --git a/hybridgnn/nn/models/hybrid_rhstransformer.py b/hybridgnn/nn/models/hybrid_rhstransformer.py index a16de0a..a22202d 100644 --- a/hybridgnn/nn/models/hybrid_rhstransformer.py +++ b/hybridgnn/nn/models/hybrid_rhstransformer.py @@ -184,39 +184,6 @@ def get_rhs_time_dict( dst_entity_col, dst_entity_table, ): - # edge_index_dict keys - """ - dict_keys([('drop_withdrawals', 'f2p_nct_id', 'studies'), - ('studies', 'rev_f2p_nct_id', 'drop_withdrawals'), - ('outcomes', 'f2p_nct_id', 'studies'), - ('studies', 'rev_f2p_nct_id', 'outcomes'), - ('outcome_analyses', 'f2p_nct_id', 'studies'), - ('studies', 'rev_f2p_nct_id', 'outcome_analyses'), - ('outcome_analyses', 'f2p_outcome_id', 'outcomes'), - ('outcomes', 'rev_f2p_outcome_id', 'outcome_analyses'), - ('eligibilities', 'f2p_nct_id', 'studies'), - ('studies', 'rev_f2p_nct_id', 'eligibilities'), - ('sponsors_studies', 'f2p_nct_id', 'studies'), - ('studies', 'rev_f2p_nct_id', 'sponsors_studies'), - ('sponsors_studies', 'f2p_sponsor_id', 'sponsors'), - ('sponsors', 'rev_f2p_sponsor_id', 'sponsors_studies'), - ('facilities_studies', 'f2p_nct_id', 'studies'), - ('studies', 'rev_f2p_nct_id', 'facilities_studies'), - ('facilities_studies', 'f2p_facility_id', 'facilities'), - ('facilities', 'rev_f2p_facility_id', 'facilities_studies'), - ('interventions_studies', 'f2p_nct_id', 'studies'), - ('studies', 'rev_f2p_nct_id', 'interventions_studies'), - ('interventions_studies', 'f2p_intervention_id', 'interventions'), - ('interventions', 'rev_f2p_intervention_id', 'interventions_studies'), - ('designs', 'f2p_nct_id', 'studies'), - ('studies', 'rev_f2p_nct_id', 'designs'), - ('reported_event_totals', 'f2p_nct_id', 'studies'), - ('studies', 'rev_f2p_nct_id', 'reported_event_totals'), - ('conditions_studies', 'f2p_nct_id', 'studies'), - ('studies', 'rev_f2p_nct_id', 'conditions_studies'), - 
('conditions_studies', 'f2p_condition_id', 'conditions'), - ('conditions', 'rev_f2p_condition_id', 'conditions_studies')]) - """ #* what to put when transaction table is merged edge_index = edge_index_dict['sponsors','f2p_sponsor_id', 'sponsors_studies'] diff --git a/hybridgnn/nn/models/transformer.py b/hybridgnn/nn/models/transformer.py index d227d8b..f7e1f14 100644 --- a/hybridgnn/nn/models/transformer.py +++ b/hybridgnn/nn/models/transformer.py @@ -42,7 +42,7 @@ def __init__( self.pe = None if (position_encoding == "abs"): self.pe = PositionalEncoding(hidden_channels) - elif (position_encoding == "none"): + elif (position_encoding is None): self.pe = None else: raise NotImplementedError From 29c70ad25c799ddf239f65314846a8f2484d6039 Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Wed, 14 Aug 2024 19:41:10 +0000 Subject: [PATCH 10/22] add time fuse encoder to extract time pe --- .../relbench_link_prediction_benchmark.py | 6 +++--- hybridgnn/nn/encoder.py | 20 +++++++++++++++++-- hybridgnn/nn/models/transformer.py | 3 --- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index 2586e86..080ccf5 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -133,11 +133,11 @@ "embedding_dim": [64], "norm": ["layer_norm"], "dropout": [0.1, 0.2], - "pe": ["abs", None], + "pe": [None], } train_search_space = { - "batch_size": [64], - "base_lr": [0.001, 0.01, 0.0001], + "batch_size": [512], + "base_lr": [0.0005, 0.01], "gamma_rate": [0.9, 1.0], } model_cls = Hybrid_RHSTransformer diff --git a/hybridgnn/nn/encoder.py b/hybridgnn/nn/encoder.py index 022b30a..c56ad00 100644 --- a/hybridgnn/nn/encoder.py +++ b/hybridgnn/nn/encoder.py @@ -19,6 +19,9 @@ torch_frame.timestamp: (torch_frame.nn.TimestampEncoder, {}), } SECONDS_IN_A_DAY = 60 * 60 * 24 +SECONDS_IN_A_WEEK = 7 * 60 * 60 * 24 +SECONDS_IN_A_HOUR = 60 * 60 +SECONDS_IN_A_MINUTE = 60 class HeteroEncoder(torch.nn.Module): @@ -109,11 +112,15 @@ def __init__(self, node_types: List[NodeType], channels: int) -> None: for node_type in node_types }) + time_dim = 3 # hour, day, week + self.time_fuser = torch.nn.Linear(time_dim, channels) + def reset_parameters(self) -> None: for encoder in self.encoder_dict.values(): encoder.reset_parameters() for lin in self.lin_dict.values(): lin.reset_parameters() + self.time_fuser.reset_parameters() def forward( self, @@ -125,9 +132,18 @@ def forward( for node_type, time in time_dict.items(): rel_time = seed_time[batch_dict[node_type]] - time - rel_time = rel_time / SECONDS_IN_A_DAY - x = self.encoder_dict[node_type](rel_time) + # rel_day = rel_time / SECONDS_IN_A_DAY + # x = self.encoder_dict[node_type](rel_day) + # x = self.encoder_dict[node_type](rel_hour) + rel_hour = (rel_time // SECONDS_IN_A_HOUR).view(-1,1) + rel_day = (rel_time // SECONDS_IN_A_DAY).view(-1,1) + rel_week = (rel_time // SECONDS_IN_A_WEEK).view(-1,1) + time_embed = torch.cat((rel_hour, rel_day, rel_week),dim=1).float() + + #! 
might need to normalize hour, day, week into the same scale + time_embed = torch.nn.functional.normalize(time_embed, p=2.0, dim=1) + x = self.time_fuser(time_embed) x = self.lin_dict[node_type](x) out_dict[node_type] = x diff --git a/hybridgnn/nn/models/transformer.py b/hybridgnn/nn/models/transformer.py index f7e1f14..a9cf2ac 100644 --- a/hybridgnn/nn/models/transformer.py +++ b/hybridgnn/nn/models/transformer.py @@ -77,7 +77,6 @@ def forward(self, rhs_embed: Tensor, index: Tensor, batch_size) -> Tensor: index = index[sorted_idx] rhs_embed = rhs_embed[sorted_idx] reverse = self.inverse_permutation(sorted_idx) - # assert torch.equal(index, sorted_index) x, mask = to_dense_batch(rhs_embed, index, batch_size=batch_size) for block in self.blocks: @@ -85,8 +84,6 @@ def forward(self, rhs_embed: Tensor, index: Tensor, batch_size) -> Tensor: x = x[mask] x = x.view(-1, self.hidden_channels) x = x[reverse] - # x = x.gather(1, sorted_idx.argsort(1)) - return self.fc(x) def inverse_permutation(self,perm): From f58b52844b5bdabb5b1b995cc1e5beb1d13c9828 Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Wed, 14 Aug 2024 19:42:48 +0000 Subject: [PATCH 11/22] update hyperparameter options --- benchmark/relbench_link_prediction_benchmark.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index 080ccf5..5331f7c 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -129,14 +129,14 @@ model_cls = (HybridGNN if args.model == "hybridgnn" else ShallowRHSGNN) elif args.model in ["rhstransformer"]: model_search_space = { - "channels": [64], - "embedding_dim": [64], + "channels": [64, 128], + "embedding_dim": [64, 128], "norm": ["layer_norm"], "dropout": [0.1, 0.2], "pe": [None], } train_search_space = { - "batch_size": [512], + "batch_size": [128, 256, 512], "base_lr": [0.0005, 0.01], "gamma_rate": [0.9, 1.0], } From 8f4966cec374bb76797edfc47160d1e0036d2a77 Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Mon, 19 Aug 2024 22:54:33 +0000 Subject: [PATCH 12/22] adding rerank_transformer --- .../relbench_link_prediction_benchmark.py | 35 ++- hybridgnn/nn/models/__init__.py | 3 +- hybridgnn/nn/models/hybrid_rhstransformer.py | 4 - hybridgnn/nn/models/rerank_transformer.py | 216 ++++++++++++++++++ 4 files changed, 248 insertions(+), 10 deletions(-) create mode 100644 hybridgnn/nn/models/rerank_transformer.py diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index 5331f7c..715c77f 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -30,7 +30,7 @@ from torch_geometric.utils.cross_entropy import sparse_cross_entropy from tqdm import tqdm -from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN, Hybrid_RHSTransformer +from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN, Hybrid_RHSTransformer, ReRankTransformer from hybridgnn.utils import GloveTextEmbedding TRAIN_CONFIG_KEYS = ["batch_size", "gamma_rate", "base_lr"] @@ -43,7 +43,7 @@ "--model", type=str, default="hybridgnn", - choices=["hybridgnn", "idgnn", "shallowrhsgnn", "rhstransformer"], + choices=["hybridgnn", "idgnn", "shallowrhsgnn", "rhstransformer", "rerank_transformer"], ) parser.add_argument("--epochs", type=int, default=20) parser.add_argument("--num_trials", type=int, default=10, @@ -102,7 +102,7 @@ int(args.num_neighbors // 
2**i) for i in range(args.num_layers) ] -model_cls: Type[Union[IDGNN, HybridGNN, ShallowRHSGNN, Hybrid_RHSTransformer]] +model_cls: Type[Union[IDGNN, HybridGNN, ShallowRHSGNN, Hybrid_RHSTransformer, ReRankTransformer]] if args.model == "idgnn": model_search_space = { @@ -141,7 +141,20 @@ "gamma_rate": [0.9, 1.0], } model_cls = Hybrid_RHSTransformer - +elif args.model in ["rerank_transformer"]: + model_search_space = { + "channels": [64], + "embedding_dim": [64], + "norm": ["layer_norm"], + "dropout": [0.0, 0.1, 0.2], + "rank_topk": [25,50,100] + } + train_search_space = { + "batch_size": [128, 256, 512], + "base_lr": [0.0005, 0.01], + "gamma_rate": [0.9, 1.0], + } + model_cls = ReRankTransformer def train( model: torch.nn.Module, @@ -188,6 +201,13 @@ def train( edge_label_index = torch.stack([src_batch, dst_index], dim=0) loss = sparse_cross_entropy(logits, edge_label_index) numel = len(batch[task.dst_entity_table].batch) + elif args.model in ["rerank_transformer"]: + gnn_logits, tr_logits, topk_idx = model(batch, task.src_entity_table, task.dst_entity_table, task.dst_entity_col) + edge_label_index = torch.stack([src_batch, dst_index], dim=0) + loss = sparse_cross_entropy(gnn_logits, edge_label_index) + loss += sparse_cross_entropy(tr_logits, edge_label_index) + numel = len(batch[task.dst_entity_table].batch) + loss.backward() optimizer.step() @@ -234,6 +254,11 @@ def test(model: torch.nn.Module, loader: NeighborLoader, stage: str) -> float: task.dst_entity_table, task.dst_entity_col).detach() scores = torch.sigmoid(out) + elif args.model in ["rerank_transformer"]: + _, out, _ = model(batch, task.src_entity_table, + task.dst_entity_table, + task.dst_entity_col) + scores = torch.sigmoid(out.detach()) else: raise ValueError(f"Unsupported model type: {args.model}.") @@ -272,7 +297,7 @@ def train_and_eval_with_cfg( persistent_workers=args.num_workers > 0, ) - if args.model in ["hybridgnn", "shallowrhsgnn", "rhstransformer"]: + if args.model in ["hybridgnn", "shallowrhsgnn", "rhstransformer", "rerank_transformer"]: model_cfg["num_nodes"] = num_dst_nodes_dict["train"] elif args.model == "idgnn": model_cfg["out_channels"] = 1 diff --git a/hybridgnn/nn/models/__init__.py b/hybridgnn/nn/models/__init__.py index f20d482..5d9f3cf 100644 --- a/hybridgnn/nn/models/__init__.py +++ b/hybridgnn/nn/models/__init__.py @@ -3,8 +3,9 @@ from .hybridgnn import HybridGNN from .shallowrhsgnn import ShallowRHSGNN from .hybrid_rhstransformer import Hybrid_RHSTransformer +from .rerank_transformer import ReRankTransformer __all__ = classes = [ 'HeteroGraphSAGE', 'IDGNN', 'HybridGNN', 'ShallowRHSGNN', - 'Hybrid_RHSTransformer' + 'Hybrid_RHSTransformer', 'ReRankTransformer' ] diff --git a/hybridgnn/nn/models/hybrid_rhstransformer.py b/hybridgnn/nn/models/hybrid_rhstransformer.py index a22202d..e373ef4 100644 --- a/hybridgnn/nn/models/hybrid_rhstransformer.py +++ b/hybridgnn/nn/models/hybrid_rhstransformer.py @@ -111,10 +111,6 @@ def forward( dst_table: NodeType, dst_entity_col: NodeType, ) -> Tensor: - # print ("time dict has the following keys") - # print (batch.time_dict.keys()) - # dict_keys(['drop_withdrawals', 'outcomes', 'outcome_analyses', 'eligibilities', 'sponsors_studies', 'facilities_studies', 'interventions_studies', 'studies', 'designs', 'reported_event_totals', 'conditions_studies']) - seed_time = batch[entity_table].seed_time x_dict = self.encoder(batch.tf_dict) diff --git a/hybridgnn/nn/models/rerank_transformer.py b/hybridgnn/nn/models/rerank_transformer.py new file mode 100644 index 0000000..e2760fc 
--- /dev/null +++ b/hybridgnn/nn/models/rerank_transformer.py @@ -0,0 +1,216 @@ +from typing import Any, Dict + +import torch +from torch import Tensor +from torch_frame.data.stats import StatType +from torch_geometric.data import HeteroData +from torch_geometric.nn import MLP +from torch_geometric.typing import NodeType + +from hybridgnn.nn.encoder import ( + DEFAULT_STYPE_ENCODER_DICT, + HeteroEncoder, + HeteroTemporalEncoder, +) +from hybridgnn.nn.models import HeteroGraphSAGE +from torch_scatter import scatter_max +from torch_geometric.nn.aggr.utils import MultiheadAttentionBlock +from torch_geometric.utils import to_dense_batch + + + +class ReRankTransformer(torch.nn.Module): + r"""Implementation of ReRank Transformer model. + Args: + data (HeteroData): dataset + col_stats_dict (Dict[str, Dict[str, Dict[StatType, Any]]]): column stats + num_nodes (int): number of nodes, + num_layers (int): number of mp layers, + channels (int): input dimension, + embedding_dim (int): embedding dimension size, + aggr (str): aggregation type, + norm (norm): normalization type, + dropout (float): dropout rate for the transformer float, + heads (int): number of attention heads, + rank_topk (int): how many top results of gnn would be reranked,""" + def __init__( + self, + data: HeteroData, + col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]], + num_nodes: int, + num_layers: int, + channels: int, + embedding_dim: int, + aggr: str = 'sum', + norm: str = 'layer_norm', + dropout: float = 0.2, + heads: int = 1, + rank_topk: int = 100, + ) -> None: + super().__init__() + + self.encoder = HeteroEncoder( + channels=channels, + node_to_col_names_dict={ + node_type: data[node_type].tf.col_names_dict + for node_type in data.node_types + }, + node_to_col_stats=col_stats_dict, + stype_encoder_cls_kwargs=DEFAULT_STYPE_ENCODER_DICT, + ) + self.temporal_encoder = HeteroTemporalEncoder( + node_types=[ + node_type for node_type in data.node_types + if "time" in data[node_type] + ], + channels=channels, + ) + self.gnn = HeteroGraphSAGE( + node_types=data.node_types, + edge_types=data.edge_types, + channels=channels, + aggr=aggr, + num_layers=num_layers, + ) + self.head = MLP( + channels, + out_channels=1, + norm=norm, + num_layers=1, + ) + self.lhs_projector = torch.nn.Linear(channels, embedding_dim) + self.id_awareness_emb = torch.nn.Embedding(1, channels) + self.rhs_embedding = torch.nn.Embedding(num_nodes, embedding_dim) + self.lin_offset_idgnn = torch.nn.Linear(embedding_dim, 1) + self.lin_offset_embgnn = torch.nn.Linear(embedding_dim, 1) + + self.rank_topk = rank_topk + self.tr_blocks = torch.nn.ModuleList([ + MultiheadAttentionBlock( + channels=embedding_dim, + heads=heads, + layer_norm=True, + dropout=dropout, + ) for _ in range(1) + ]) + self.channels = channels + + self.reset_parameters() + + def reset_parameters(self) -> None: + self.encoder.reset_parameters() + self.temporal_encoder.reset_parameters() + self.gnn.reset_parameters() + self.head.reset_parameters() + self.id_awareness_emb.reset_parameters() + self.rhs_embedding.reset_parameters() + self.lin_offset_embgnn.reset_parameters() + self.lin_offset_idgnn.reset_parameters() + self.lhs_projector.reset_parameters() + for block in self.tr_blocks: + block.reset_parameters() + + def forward( + self, + batch: HeteroData, + entity_table: NodeType, + dst_table: NodeType, + dst_entity_col: NodeType, + ) -> Tensor: + + seed_time = batch[entity_table].seed_time + x_dict = self.encoder(batch.tf_dict) + + # Add ID-awareness to the root node + 
x_dict[entity_table][:seed_time.size(0 + )] += self.id_awareness_emb.weight + rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict, + batch.batch_dict) + + for node_type, rel_time in rel_time_dict.items(): + x_dict[node_type] = x_dict[node_type] + rel_time + + x_dict = self.gnn( + x_dict, + batch.edge_index_dict, + ) + + batch_size = seed_time.size(0) + lhs_embedding = x_dict[entity_table][: + batch_size] # batch_size, channel + lhs_embedding_projected = self.lhs_projector(lhs_embedding) + rhs_gnn_embedding = x_dict[dst_table] # num_sampled_rhs, channel + rhs_idgnn_index = batch.n_id_dict[dst_table] # num_sampled_rhs + lhs_idgnn_batch = batch.batch_dict[dst_table] # batch_size + + rhs_embedding = self.rhs_embedding # num_rhs_nodes, channel + embgnn_logits = lhs_embedding_projected @ rhs_embedding.weight.t( + ) # batch_size, num_rhs_nodes + + # Model the importance of embedding-GNN prediction for each lhs node + embgnn_offset_logits = self.lin_offset_embgnn( + lhs_embedding_projected).flatten() + embgnn_logits += embgnn_offset_logits.view(-1, 1) + + # Calculate idgnn logits + idgnn_logits = self.head( + rhs_gnn_embedding).flatten() # num_sampled_rhs + # Because we are only doing 2 hop, we are not really sampling info from + # lhs therefore, we need to incorporate this information using + # lhs_embedding[lhs_idgnn_batch] * rhs_gnn_embedding + idgnn_logits += ( + lhs_embedding[lhs_idgnn_batch] * # num_sampled_rhs, channel + rhs_gnn_embedding).sum( + dim=-1).flatten() # num_sampled_rhs, channel + + # Model the importance of ID-GNN prediction for each lhs node + idgnn_offset_logits = self.lin_offset_idgnn( + lhs_embedding_projected).flatten() + idgnn_logits = idgnn_logits + idgnn_offset_logits[lhs_idgnn_batch] + + embgnn_logits[lhs_idgnn_batch, rhs_idgnn_index] = idgnn_logits + + + + """ + detach the variable + """ + all_rhs_embed = rhs_embedding.weight.detach().clone() #only shallow rhs embeds + assert all_rhs_embed.shape[1] == rhs_gnn_embedding.shape[1], "id GNN embed size should be the same as shallow RHS embed size" + all_rhs_embed[rhs_idgnn_index] = rhs_gnn_embedding.detach().clone() # apply the idGNN embeddings here + + + # all_rhs_embed = rhs_embedding.weight #only shallow rhs embeds + # #! this causes error when the channel size and hidden size is different + # assert all_rhs_embed.shape[1] == rhs_gnn_embedding.shape[1], "id GNN embed size should be the same as shallow RHS embed size" + # all_rhs_embed[rhs_idgnn_index] = rhs_gnn_embedding # apply the idGNN embeddings here + + + transformer_logits, topk_index = self.rerank(embgnn_logits.detach().clone(), all_rhs_embed, lhs_idgnn_batch.detach().clone(), lhs_embedding[lhs_idgnn_batch].detach().clone()) + # transformer_logits = self.rerank(embgnn_logits, all_rhs_embed, lhs_idgnn_batch, lhs_embedding[lhs_idgnn_batch]) + + + + # return embgnn_logits, transformer_logits + return embgnn_logits, transformer_logits, topk_index + + + def rerank(self, gnn_logits, rhs_gnn_embedding, index, lhs_embedding): + """ + reranks the gnn logits based on the provided gnn embeddings. + rhs_gnn_embedding:[# rhs nodes, embed_dim] + """ + topk = self.rank_topk + _, topk_index = torch.topk(gnn_logits, self.rank_topk, dim=1) + embed_size = rhs_gnn_embedding.shape[1] + + # need input batch of size [# nodes, topk, embed_size] + top_embed = torch.stack([rhs_gnn_embedding[topk_index[idx]] for idx in range(topk_index.shape[0])]) + for block in self.tr_blocks: + tr_embed = block(top_embed, top_embed) # [# nodes, topk, embed_size] + + #! 
for top-k prediction
+        # tr_logits = torch.stack([(lhs_embedding[idx] * tr_embed[idx]).sum(dim=-1).flatten() for idx in range(topk_index.shape[0])])
+        for idx in range(topk_index.shape[0]):
+            gnn_logits[idx][topk_index[idx]] = (lhs_embedding[idx] * tr_embed[idx]).sum(dim=-1).flatten()
+        return gnn_logits, topk_index
\ No newline at end of file
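[Editor's note] The per-row Python loop in the rerank step above can be written with batched tensor ops instead. Below is a minimal, self-contained sketch of the same top-k rerank idea; it is illustrative only, and every name in it is hypothetical rather than part of this repository. A plain torch.nn.MultiheadAttention stands in for the MultiheadAttentionBlock the model actually uses.

    import torch

    def rerank_sketch(gnn_logits: torch.Tensor,   # [batch_size, num_rhs]
                      rhs_embed: torch.Tensor,    # [num_rhs, dim]
                      lhs_embed: torch.Tensor,    # [batch_size, dim]
                      attn: torch.nn.MultiheadAttention,
                      topk: int = 100):
        # Pick the top-k candidate destinations per seed from the GNN scores.
        _, topk_index = torch.topk(gnn_logits, topk, dim=1)    # [batch_size, topk]
        # Advanced indexing replaces the torch.stack list comprehension.
        top_embed = rhs_embed[topk_index]                      # [batch_size, topk, dim]
        # Let each seed's candidates attend to one another.
        tr_embed, _ = attn(top_embed, top_embed, top_embed)    # [batch_size, topk, dim]
        # Re-score every candidate against its seed embedding.
        new_scores = (lhs_embed.unsqueeze(1) * tr_embed).sum(dim=-1)  # [batch_size, topk]
        # Scatter the refreshed scores back; entries outside the top-k stay at zero.
        out_logits = torch.zeros_like(gnn_logits)
        out_logits.scatter_(1, topk_index, new_scores)
        return out_logits, topk_index

An attn built as torch.nn.MultiheadAttention(dim, num_heads=1, batch_first=True) matches the shapes above, and zero-filling out_logits mirrors what the next commit, "setting zeros for not used logits", introduces explicitly.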
From 90a8817de472a773a3f2616ce137314daa386f5 Mon Sep 17 00:00:00 2001
From: shenyang huang
Date: Mon, 19 Aug 2024 23:42:22 +0000
Subject: [PATCH 13/22] setting zeros for not used logits in rerank transformer

---
 benchmark/relbench_link_prediction_benchmark.py | 13 ++++++++++---
 hybridgnn/nn/models/rerank_transformer.py       |  8 +++-----
 2 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py
index 715c77f..8bf4ab9 100644
--- a/benchmark/relbench_link_prediction_benchmark.py
+++ b/benchmark/relbench_link_prediction_benchmark.py
@@ -147,7 +147,7 @@
         "embedding_dim": [64],
         "norm": ["layer_norm"],
         "dropout": [0.0, 0.1, 0.2],
-        "rank_topk": [25, 50, 100]
+        "rank_topk": [25, 50, 100, 200]
     }
     train_search_space = {
         "batch_size": [128, 256, 512],
@@ -255,10 +255,17 @@ def test(model: torch.nn.Module, loader: NeighborLoader, stage: str) -> float:
                        task.dst_entity_col).detach()
             scores = torch.sigmoid(out)
         elif args.model in ["rerank_transformer"]:
-            _, out, _ = model(batch, task.src_entity_table,
+            gnn_logits, tr_logits, topk_index = model(batch, task.src_entity_table,
                        task.dst_entity_table,
                        task.dst_entity_col)
-            scores = torch.sigmoid(out.detach())
+            gnn_logits, tr_logits, topk_index = gnn_logits.detach(), tr_logits.detach(), topk_index.detach()
+            for idx in range(topk_index.shape[0]):
+                gnn_logits[idx][topk_index[idx]] = tr_logits[idx][topk_index[idx]]
+
+            scores = torch.sigmoid(gnn_logits)
         else:
             raise ValueError(f"Unsupported model type: {args.model}.")
diff --git a/hybridgnn/nn/models/rerank_transformer.py b/hybridgnn/nn/models/rerank_transformer.py
index e2760fc..a62d5bf 100644
--- a/hybridgnn/nn/models/rerank_transformer.py
+++ b/hybridgnn/nn/models/rerank_transformer.py
@@ -189,9 +189,6 @@ def forward(
         transformer_logits, topk_index = self.rerank(embgnn_logits.detach().clone(), all_rhs_embed, lhs_idgnn_batch.detach().clone(), lhs_embedding[lhs_idgnn_batch].detach().clone())
         # transformer_logits = self.rerank(embgnn_logits, all_rhs_embed, lhs_idgnn_batch, lhs_embedding[lhs_idgnn_batch])
-
-
-
         # return embgnn_logits, transformer_logits
         return embgnn_logits, transformer_logits, topk_index
@@ -210,7 +207,8 @@ def rerank(self, gnn_logits, rhs_gnn_embedding, index, lhs_embedding):
         for block in self.tr_blocks:
             tr_embed = block(top_embed, top_embed)  # [# nodes, topk, embed_size]
 
         #! for top-k prediction
+        out_logits = torch.zeros(gnn_logits.shape).to(gnn_logits.device)
         # tr_logits = torch.stack([(lhs_embedding[idx] * tr_embed[idx]).sum(dim=-1).flatten() for idx in range(topk_index.shape[0])])
         for idx in range(topk_index.shape[0]):
-            gnn_logits[idx][topk_index[idx]] = (lhs_embedding[idx] * tr_embed[idx]).sum(dim=-1).flatten()
-        return gnn_logits, topk_index
\ No newline at end of file
+            out_logits[idx][topk_index[idx]] = (lhs_embedding[idx] * tr_embed[idx]).sum(dim=-1).flatten()
+        return out_logits, topk_index
\ No newline at end of file

From 1eb769cdd4cf75a3ad732209ee385cdcc6516aee Mon Sep 17 00:00:00 2001
From: shenyang huang
Date: Wed, 21 Aug 2024 18:44:50 +0000
Subject: [PATCH 14/22] adding reranker transformer

---
 .../relbench_link_prediction_benchmark.py    | 27 ++++++---
 hybridgnn/nn/models/rerank_transformer.py    | 59 ++++++++++++++-----
 2 files changed, 63 insertions(+), 23 deletions(-)

diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py
index 8bf4ab9..3b5756d 100644
--- a/benchmark/relbench_link_prediction_benchmark.py
+++ b/benchmark/relbench_link_prediction_benchmark.py
@@ -32,6 +32,9 @@
 from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN, Hybrid_RHSTransformer, ReRankTransformer
 from hybridgnn.utils import GloveTextEmbedding
 
+from torch_geometric.utils.map import map_index
+
+
 TRAIN_CONFIG_KEYS = ["batch_size", "gamma_rate", "base_lr"]
 LINK_PREDICTION_METRIC = "link_prediction_map"
 
@@ -146,8 +149,8 @@
         "channels": [64],
         "embedding_dim": [64],
         "norm": ["layer_norm"],
-        "dropout": [0.0, 0.1, 0.2],
-        "rank_topk": [25, 50, 100, 200]
+        "dropout": [0.1, 0.2],
+        "rank_topk": [100]
     }
     train_search_space = {
         "batch_size": [128, 256, 512],
         "base_lr": [0.0005, 0.01],
         "gamma_rate": [0.9, 1.0],
     }
@@ -205,6 +208,19 @@
         elif args.model in ["rerank_transformer"]:
             gnn_logits, tr_logits, topk_idx = model(batch, task.src_entity_table, task.dst_entity_table, task.dst_entity_col)
             edge_label_index = torch.stack([src_batch, dst_index], dim=0)
             loss = sparse_cross_entropy(gnn_logits, edge_label_index)
continue here to debug for map_index to only get label for the topk that transformer learns + """ + # batch_size = batch[task.src_entity_table].batch_size + # target = torch.isin( + # batch[task.dst_entity_table].batch + + # batch_size * batch[task.dst_entity_table].n_id, + # src_batch + batch_size * dst_index, + # ).float() + # print (target.shape) + # quit() + # topk_labels = map_index(edge_label_index, topk_idx) + """ loss += sparse_cross_entropy(tr_logits, edge_label_index) numel = len(batch[task.dst_entity_table].batch) @@ -258,12 +274,7 @@ def test(model: torch.nn.Module, loader: NeighborLoader, stage: str) -> float: gnn_logits, tr_logits, topk_index = model(batch, task.src_entity_table, task.dst_entity_table, task.dst_entity_col) - gnn_logits, tr_logits, topk_index = gnn_logits.detach(), tr_logits.detach(), topk_index.detach() - for idx in range(topk_index.shape[0]): - gnn_logits[idx][topk_index[idx]] = tr_logits[idx][topk_index[idx]] - - scores = torch.sigmoid(gnn_logits) - #scores = torch.sigmoid(out.detach()) + scores = torch.sigmoid(tr_logits.detach()) else: diff --git a/hybridgnn/nn/models/rerank_transformer.py b/hybridgnn/nn/models/rerank_transformer.py index a62d5bf..e1fb073 100644 --- a/hybridgnn/nn/models/rerank_transformer.py +++ b/hybridgnn/nn/models/rerank_transformer.py @@ -16,6 +16,7 @@ from torch_scatter import scatter_max from torch_geometric.nn.aggr.utils import MultiheadAttentionBlock from torch_geometric.utils import to_dense_batch +from torch_geometric.utils.map import map_index @@ -93,6 +94,8 @@ def __init__( dropout=dropout, ) for _ in range(1) ]) + # self.tr_lin = torch.nn.Linear(embedding_dim*2, embedding_dim) + self.channels = channels self.reset_parameters() @@ -109,6 +112,7 @@ def reset_parameters(self) -> None: self.lhs_projector.reset_parameters() for block in self.tr_blocks: block.reset_parameters() + # self.tr_lin.reset_parameters() def forward( self, @@ -170,27 +174,52 @@ def forward( embgnn_logits[lhs_idgnn_batch, rhs_idgnn_index] = idgnn_logits - - - """ - detach the variable - """ - all_rhs_embed = rhs_embedding.weight.detach().clone() #only shallow rhs embeds + #! let's do end to end transformer here + all_rhs_embed = rhs_embedding.weight #only shallow rhs embeds assert all_rhs_embed.shape[1] == rhs_gnn_embedding.shape[1], "id GNN embed size should be the same as shallow RHS embed size" - all_rhs_embed[rhs_idgnn_index] = rhs_gnn_embedding.detach().clone() # apply the idGNN embeddings here - - # all_rhs_embed = rhs_embedding.weight #only shallow rhs embeds - # #! 
this causes error when the channel size and hidden size is different - # assert all_rhs_embed.shape[1] == rhs_gnn_embedding.shape[1], "id GNN embed size should be the same as shallow RHS embed size" + #* rhs_gnn_embedding is significantly smaller than rhs_embed and we can't use inplace operation during backprop + #* -----> this is not global, can't replace like this + copy_tensor = torch.zeros(all_rhs_embed.shape).to(all_rhs_embed.device) + copy_tensor[rhs_idgnn_index] = rhs_gnn_embedding + final_rhs_embed = all_rhs_embed + copy_tensor # all_rhs_embed[rhs_idgnn_index] = rhs_gnn_embedding # apply the idGNN embeddings here + # transformer_logits, topk_index = self.rerank(embgnn_logits.detach().clone(), final_rhs_embed, lhs_idgnn_batch.detach().clone(), lhs_embedding[lhs_idgnn_batch].detach().clone()) + transformer_logits, topk_index = self.rerank(embgnn_logits, final_rhs_embed, lhs_idgnn_batch, lhs_embedding_projected[lhs_idgnn_batch]) - transformer_logits, topk_index = self.rerank(embgnn_logits.detach().clone(), all_rhs_embed, lhs_idgnn_batch.detach().clone(), lhs_embedding[lhs_idgnn_batch].detach().clone()) - # transformer_logits = self.rerank(embgnn_logits, all_rhs_embed, lhs_idgnn_batch, lhs_embedding[lhs_idgnn_batch]) return embgnn_logits, transformer_logits, topk_index + #* adding lhs embedding code not working yet + # def rerank(self, gnn_logits, rhs_gnn_embedding, index, lhs_embedding): + # """ + # reranks the gnn logits based on the provided gnn embeddings. + # rhs_gnn_embedding:[# rhs nodes, embed_dim] + # """ + # topk = self.rank_topk + # _, topk_index = torch.topk(gnn_logits, self.rank_topk, dim=1) + # embed_size = rhs_gnn_embedding.shape[1] + + # # need input batch of size [# nodes, topk, embed_size] + # #! concatenate the lhs embedding with rhs embedding + # top_embed = torch.stack([torch.cat((rhs_gnn_embedding[topk_index[idx]],lhs_embedding[idx].view(1,-1).expand(self.rank_topk,-1)), dim=1) for idx in range(topk_index.shape[0])]) + # tr_embed = top_embed + # for block in self.tr_blocks: + # tr_embed = block(tr_embed, tr_embed) # [# nodes, topk, embed_size] + + # tr_embed = tr_embed.view(-1,embed_size*2) + # tr_embed = self.tr_lin(tr_embed) + # tr_embed = tr_embed.view(-1,self.rank_topk,embed_size) + + + # #! for top k prediction + # out_logits = torch.full(gnn_logits.shape, -float('inf')).to(gnn_logits.device) + # # tr_logits = torch.stack([(lhs_embedding[idx] * tr_embed[idx]).sum(dim=-1).flatten() for idx in range(topk_index.shape[0])]) + # for idx in range(topk_index.shape[0]): + # out_logits[idx][topk_index[idx]] = (lhs_embedding[idx] * tr_embed[idx]).sum(dim=-1).flatten() + # return out_logits, topk_index + def rerank(self, gnn_logits, rhs_gnn_embedding, index, lhs_embedding): """ @@ -206,8 +235,8 @@ def rerank(self, gnn_logits, rhs_gnn_embedding, index, lhs_embedding): for block in self.tr_blocks: tr_embed = block(top_embed, top_embed) # [# nodes, topk, embed_size] - #! for top 50 prediction - out_logits = torch.zeros(gnn_logits.shape).to(gnn_logits.device) + #! 
for top k prediction + out_logits = torch.full(gnn_logits.shape, -float('inf')).to(gnn_logits.device) # tr_logits = torch.stack([(lhs_embedding[idx] * tr_embed[idx]).sum(dim=-1).flatten() for idx in range(topk_index.shape[0])]) for idx in range(topk_index.shape[0]): out_logits[idx][topk_index[idx]] = (lhs_embedding[idx] * tr_embed[idx]).sum(dim=-1).flatten() From 75445cee5618d35974b79b5c32b9e298b747eab3 Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Wed, 21 Aug 2024 23:23:38 +0000 Subject: [PATCH 15/22] updating RHS transformer code --- .../relbench_link_prediction_benchmark.py | 27 ++++--- hybridgnn/nn/encoder.py | 42 +++++++---- ...id_rhstransformer.py => RHSTransformer.py} | 75 +++++-------------- hybridgnn/nn/models/__init__.py | 4 +- hybridgnn/nn/models/transformer.py | 39 +--------- 5 files changed, 66 insertions(+), 121 deletions(-) rename hybridgnn/nn/models/{hybrid_rhstransformer.py => RHSTransformer.py} (73%) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index a375492..6ec54be 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -1,3 +1,7 @@ +""" +$ python relbench_link_prediction_benchmark.py --dataset rel-stack --task post-post-related --model rhstransformer --num_trials 10 +""" + import argparse import json import os @@ -30,7 +34,7 @@ from torch_geometric.utils.cross_entropy import sparse_cross_entropy from tqdm import tqdm -from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN, Hybrid_RHSTransformer, ReRankTransformer +from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN, RHSTransformer, ReRankTransformer from hybridgnn.utils import GloveTextEmbedding from torch_geometric.utils.map import map_index @@ -105,7 +109,7 @@ int(args.num_neighbors // 2**i) for i in range(args.num_layers) ] -model_cls: Type[Union[IDGNN, HybridGNN, ShallowRHSGNN, Hybrid_RHSTransformer, ReRankTransformer]] +model_cls: Type[Union[IDGNN, HybridGNN, ShallowRHSGNN, RHSTransformer, ReRankTransformer]] if args.model == "idgnn": model_search_space = { @@ -136,23 +140,27 @@ model_cls = (HybridGNN if args.model == "hybridgnn" else ShallowRHSGNN) elif args.model in ["rhstransformer"]: model_search_space = { + "encoder_channels": [64, 128], + "encoder_layers": [2, 4], "channels": [64, 128], "embedding_dim": [64, 128], - "norm": ["layer_norm"], + "norm": ["layer_norm", "batch_norm"], "dropout": [0.1, 0.2], - "pe": [None], + "t_encoding_type": ["fuse", "absolute"], } train_search_space = { - "batch_size": [128, 256, 512], + "batch_size": [128, 256], "base_lr": [0.0005, 0.01], "gamma_rate": [0.9, 1.0], } - model_cls = Hybrid_RHSTransformer + model_cls = RHSTransformer elif args.model in ["rerank_transformer"]: model_search_space = { + "encoder_channels": [64, 128, 256], + "encoder_layers": [2, 4, 8], "channels": [64], "embedding_dim": [64], - "norm": ["layer_norm"], + "norm": ["layer_norm", "batch_norm"], "dropout": [0.1, 0.2], "rank_topk": [100] } @@ -204,7 +212,7 @@ def train( loss = sparse_cross_entropy(logits, edge_label_index) numel = len(batch[task.dst_entity_table].batch) elif args.model in ["rhstransformer"]: - logits = model(batch, task.src_entity_table, task.dst_entity_table, task.dst_entity_col) + logits = model(batch, task.src_entity_table, task.dst_entity_table) edge_label_index = torch.stack([src_batch, dst_index], dim=0) loss = sparse_cross_entropy(logits, edge_label_index) numel = len(batch[task.dst_entity_table].batch) @@ -271,8 +279,7 @@ def 
test(model: torch.nn.Module, loader: NeighborLoader, stage: str) -> float:
             scores = torch.sigmoid(out)
         elif args.model in ["rhstransformer"]:
             out = model(batch, task.src_entity_table,
-                        task.dst_entity_table,
-                        task.dst_entity_col).detach()
+                        task.dst_entity_table).detach()
             scores = torch.sigmoid(out)
         elif args.model in ["rerank_transformer"]:
             gnn_logits, tr_logits, topk_index = model(batch, task.src_entity_table,
diff --git a/hybridgnn/nn/encoder.py b/hybridgnn/nn/encoder.py
index 9887fc1..2a8c933 100644
--- a/hybridgnn/nn/encoder.py
+++ b/hybridgnn/nn/encoder.py
@@ -95,9 +95,16 @@
 
 
 class HeteroTemporalEncoder(torch.nn.Module):
-    def __init__(self, node_types: List[NodeType], channels: int) -> None:
+    def __init__(self, node_types: List[NodeType], channels: int,
+                 encoding_type: Optional[str] = "absolute",) -> None:
+        r"""Temporal encoder.
+
+        encoding_type (str, optional): The type of time encoding to use. Options are ["absolute", "fuse"]; "learnable" is referenced in reset_parameters() but not yet implemented in forward().
+        """
         super().__init__()
 
+        self.encoding_type = encoding_type # ["absolute", "fuse"]
+
         self.encoder_dict = torch.nn.ModuleDict({
             node_type:
             PositionalEncoding(channels)
@@ -109,15 +116,19 @@ def __init__(self, node_types: List[NodeType], channels: int) -> None:
             for node_type in node_types
         })
 
-        time_dim = 3 # hour, day, week
-        self.time_fuser = torch.nn.Linear(time_dim, channels)
+        if (self.encoding_type == "fuse"):
+            time_dim = 3 # hour, day, week
+            self.time_fuser = torch.nn.Linear(time_dim, channels)
 
     def reset_parameters(self) -> None:
         for encoder in self.encoder_dict.values():
             encoder.reset_parameters()
         for lin in self.lin_dict.values():
             lin.reset_parameters()
-        self.time_fuser.reset_parameters()
+        if (self.encoding_type == "learnable"):
+            self.day_pe.reset_parameters()
+        elif (self.encoding_type == "fuse"):
+            self.time_fuser.reset_parameters()
 
     def forward(
         self,
@@ -129,18 +140,17 @@ def forward(
 
         for node_type, time in time_dict.items():
             rel_time = seed_time[batch_dict[node_type]] - time
-
-            # rel_day = rel_time / SECONDS_IN_A_DAY
-            # x = self.encoder_dict[node_type](rel_day)
-            # x = self.encoder_dict[node_type](rel_hour)
-            rel_hour = (rel_time // SECONDS_IN_A_HOUR).view(-1,1)
-            rel_day = (rel_time // SECONDS_IN_A_DAY).view(-1,1)
-            rel_week = (rel_time // SECONDS_IN_A_WEEK).view(-1,1)
-            time_embed = torch.cat((rel_hour, rel_day, rel_week),dim=1).float()
-
-            #! 
might need to normalize hour, day, week into the same scale - time_embed = torch.nn.functional.normalize(time_embed, p=2.0, dim=1) - x = self.time_fuser(time_embed) + + if (self.encoding_type == "absolute"): + rel_time = rel_time / SECONDS_IN_A_DAY + x = self.encoder_dict[node_type](rel_time) + elif (self.encoding_type == "fuse"): + rel_hour = (rel_time // SECONDS_IN_A_HOUR).view(-1,1) + rel_day = (rel_time // SECONDS_IN_A_DAY).view(-1,1) + rel_week = (rel_time // SECONDS_IN_A_WEEK).view(-1,1) + time_embed = torch.cat((rel_hour, rel_day, rel_week),dim=1).float() + time_embed = torch.nn.functional.normalize(time_embed, p=2.0, dim=1) #normalize hour, day, week into same scale + x = self.time_fuser(time_embed) x = self.lin_dict[node_type](x) out_dict[node_type] = x diff --git a/hybridgnn/nn/models/hybrid_rhstransformer.py b/hybridgnn/nn/models/RHSTransformer.py similarity index 73% rename from hybridgnn/nn/models/hybrid_rhstransformer.py rename to hybridgnn/nn/models/RHSTransformer.py index e373ef4..585d337 100644 --- a/hybridgnn/nn/models/hybrid_rhstransformer.py +++ b/hybridgnn/nn/models/RHSTransformer.py @@ -1,8 +1,9 @@ -from typing import Any, Dict +from typing import Any, Dict, Optional, Type import torch from torch import Tensor from torch_frame.data.stats import StatType +from torch_frame.nn.models import ResNet from torch_geometric.data import HeteroData from torch_geometric.nn import MLP from torch_geometric.typing import NodeType @@ -13,24 +14,11 @@ HeteroTemporalEncoder, ) from hybridgnn.nn.models import HeteroGraphSAGE -from hybridgnn.nn.models.transformer import RHSTransformer -from torch_scatter import scatter_max - - -class Hybrid_RHSTransformer(torch.nn.Module): - r"""Implementation of RHSTransformer model. - Args: - data (HeteroData): dataset - col_stats_dict (Dict[str, Dict[str, Dict[StatType, Any]]]): column stats - num_nodes (int): number of nodes, - num_layers (int): number of mp layers, - channels (int): input dimension, - embedding_dim (int): embedding dimension size, - aggr (str): aggregation type, - norm (norm): normalization type, - dropout (float): dropout rate for the transformer float, - heads (int): number of attention heads, - pe (str): type of positional encoding for the transformer,""" +from hybridgnn.nn.models.transformer import Transformer + + +class RHSTransformer(torch.nn.Module): + r"""Implementation of RHSTransformer model.""" def __init__( self, data: HeteroData, @@ -43,7 +31,9 @@ def __init__( norm: str = 'layer_norm', dropout: float = 0.2, heads: int = 1, - pe: str = "abs", + t_encoding_type: str = "absolute", + torch_frame_model_cls: Type[torch.nn.Module] = ResNet, + torch_frame_model_kwargs: Optional[Dict[str, Any]] = None, ) -> None: super().__init__() @@ -55,6 +45,8 @@ def __init__( }, node_to_col_stats=col_stats_dict, stype_encoder_cls_kwargs=DEFAULT_STYPE_ENCODER_DICT, + torch_frame_model_cls=torch_frame_model_cls, + torch_frame_model_kwargs=torch_frame_model_kwargs, ) self.temporal_encoder = HeteroTemporalEncoder( node_types=[ @@ -62,6 +54,7 @@ def __init__( if "time" in data[node_type] ], channels=channels, + encoding_type=t_encoding_type, ) self.gnn = HeteroGraphSAGE( node_types=data.node_types, @@ -81,14 +74,12 @@ def __init__( self.rhs_embedding = torch.nn.Embedding(num_nodes, embedding_dim) self.lin_offset_idgnn = torch.nn.Linear(embedding_dim, 1) self.lin_offset_embgnn = torch.nn.Linear(embedding_dim, 1) + self.channels = channels - self.rhs_transformer = RHSTransformer(in_channels=channels, + self.rhs_transformer = 
Transformer(in_channels=channels, out_channels=channels, hidden_channels=channels, - heads=heads, dropout=dropout, - position_encoding=pe) - - self.channels = channels + heads=heads, dropout=dropout) self.reset_parameters() @@ -109,7 +100,6 @@ def forward( batch: HeteroData, entity_table: NodeType, dst_table: NodeType, - dst_entity_col: NodeType, ) -> Tensor: seed_time = batch[entity_table].seed_time x_dict = self.encoder(batch.tf_dict) @@ -135,15 +125,8 @@ def forward( rhs_gnn_embedding = x_dict[dst_table] # num_sampled_rhs, channel rhs_idgnn_index = batch.n_id_dict[dst_table] # num_sampled_rhs lhs_idgnn_batch = batch.batch_dict[dst_table] # batch_size - - #! need custom code to work for specific datasets - # rhs_time = self.get_rhs_time_dict(batch.time_dict, batch.edge_index_dict, batch[entity_table].seed_time, batch, dst_entity_col, dst_table) - - # adding rhs transformer - rhs_gnn_embedding = self.rhs_transformer(rhs_gnn_embedding, - lhs_idgnn_batch, batch_size=batch_size) - rhs_embedding = self.rhs_embedding # num_rhs_nodes, channel + embgnn_logits = lhs_embedding_projected @ rhs_embedding.weight.t( ) # batch_size, num_rhs_nodes @@ -152,6 +135,9 @@ def forward( lhs_embedding_projected).flatten() embgnn_logits += embgnn_offset_logits.view(-1, 1) + #* transformer forward pass + rhs_gnn_embedding = self.rhs_transformer(rhs_gnn_embedding, + lhs_idgnn_batch, batch_size=batch_size) # Calculate idgnn logits idgnn_logits = self.head( rhs_gnn_embedding).flatten() # num_sampled_rhs @@ -170,24 +156,3 @@ def forward( embgnn_logits[lhs_idgnn_batch, rhs_idgnn_index] = idgnn_logits return embgnn_logits - - def get_rhs_time_dict( - self, - time_dict, - edge_index_dict, - seed_time, - batch_dict, - dst_entity_col, - dst_entity_table, - ): - #* what to put when transaction table is merged - edge_index = edge_index_dict['sponsors','f2p_sponsor_id', - 'sponsors_studies'] - rhs_time, _ = scatter_max( - time_dict['sponsors'][edge_index[0]], - edge_index[1]) - SECONDS_IN_A_DAY = 60 * 60 * 24 - NANOSECONDS_IN_A_DAY = 60 * 60 * 24 * 1000000000 - rhs_rel_time = seed_time[batch_dict[dst_entity_col]] - rhs_time - rhs_rel_time = rhs_rel_time / NANOSECONDS_IN_A_DAY - return rhs_rel_time diff --git a/hybridgnn/nn/models/__init__.py b/hybridgnn/nn/models/__init__.py index 5d9f3cf..7b6c7e8 100644 --- a/hybridgnn/nn/models/__init__.py +++ b/hybridgnn/nn/models/__init__.py @@ -2,10 +2,10 @@ from .idgnn import IDGNN from .hybridgnn import HybridGNN from .shallowrhsgnn import ShallowRHSGNN -from .hybrid_rhstransformer import Hybrid_RHSTransformer +from .RHSTransformer import RHSTransformer from .rerank_transformer import ReRankTransformer __all__ = classes = [ 'HeteroGraphSAGE', 'IDGNN', 'HybridGNN', 'ShallowRHSGNN', - 'Hybrid_RHSTransformer', 'ReRankTransformer' + 'RHSTransformer', 'ReRankTransformer' ] diff --git a/hybridgnn/nn/models/transformer.py b/hybridgnn/nn/models/transformer.py index a9cf2ac..f820802 100644 --- a/hybridgnn/nn/models/transformer.py +++ b/hybridgnn/nn/models/transformer.py @@ -10,7 +10,7 @@ from torch_geometric.nn.encoding import PositionalEncoding -class RHSTransformer(torch.nn.Module): +class Transformer(torch.nn.Module): r"""A module to attend to rhs embeddings with a transformer. Args: in_channels (int): The number of input channels of the RHS embedding. 
@@ -29,7 +29,6 @@ def __init__( heads: int = 1, num_transformer_blocks: int = 1, dropout: float = 0.0, - position_encoding: str = "abs", ) -> None: super().__init__() @@ -38,15 +37,6 @@ def __init__( self.hidden_channels = hidden_channels self.lin = torch.nn.Linear(in_channels, hidden_channels) self.fc = torch.nn.Linear(hidden_channels, out_channels) - self.pe_type = position_encoding - self.pe = None - if (position_encoding == "abs"): - self.pe = PositionalEncoding(hidden_channels) - elif (position_encoding is None): - self.pe = None - else: - raise NotImplementedError - self.blocks = torch.nn.ModuleList([ MultiheadAttentionBlock( channels=hidden_channels, @@ -68,10 +58,6 @@ def forward(self, rhs_embed: Tensor, index: Tensor, batch_size) -> Tensor: """ rhs_embed = self.lin(rhs_embed) - if (self.pe_type == "abs"): - rhs_embed = rhs_embed + self.pe( - torch.arange(rhs_embed.size(0), device=rhs_embed.device)) - # #! if we sort the index, we need to sort the rhs_embed sorted_index, sorted_idx = torch.sort(index, stable=True) index = index[sorted_idx] @@ -91,26 +77,3 @@ def inverse_permutation(self,perm): inv[perm] = torch.arange(perm.size(0), device=perm.device) return inv - -class RotaryPositionalEmbeddings(torch.nn.Module): - def __init__(self, channels, base=10000): - super().__init__() - self.channels = channels - self.base = base - self.inv_freq = 1. / (base**(torch.arange(0, channels, 2).float() / - channels)) - - def forward(self, x, pos=None): - seq_len = x.shape[1] - if (pos is None): - pos = torch.arange(seq_len, device=x.device).type_as(self.inv_freq) - freqs = torch.einsum('i,j->ij', pos, self.inv_freq) - emb = torch.cat((freqs, freqs), dim=-1) - - cos = emb.cos().to(x.device) - sin = emb.sin().to(x.device) - - x1, x2 = x[..., ::2], x[..., 1::2] - rotated = torch.stack([-x2, x1], dim=-1).reshape(x.shape).to(x.device) - - return x * cos + rotated * sin From 7380a1709f0fda2aafd4d039a9130b96e309233a Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Fri, 23 Aug 2024 19:30:49 +0000 Subject: [PATCH 16/22] updating rerank_transformer --- .../relbench_link_prediction_benchmark.py | 42 ++++---- hybridgnn/nn/models/rerank_transformer.py | 98 ++++++++++++------- 2 files changed, 83 insertions(+), 57 deletions(-) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index 6ec54be..30042eb 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -36,6 +36,7 @@ from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN, RHSTransformer, ReRankTransformer from hybridgnn.utils import GloveTextEmbedding +from torch_geometric.utils import index_to_mask from torch_geometric.utils.map import map_index @@ -71,7 +72,8 @@ parser.add_argument("--result_path", type=str, default="result") args = parser.parse_args() -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +device = "cpu" if torch.cuda.is_available(): torch.set_num_threads(1) seed_everything(args.seed) @@ -220,22 +222,21 @@ def train( gnn_logits, tr_logits, topk_idx = model(batch, task.src_entity_table, task.dst_entity_table, task.dst_entity_col) edge_label_index = torch.stack([src_batch, dst_index], dim=0) loss = sparse_cross_entropy(gnn_logits, edge_label_index) - - #! 
continue here to debug for map_index to only get label for the topk that transformer learns - """ - # batch_size = batch[task.src_entity_table].batch_size - # target = torch.isin( - # batch[task.dst_entity_table].batch + - # batch_size * batch[task.dst_entity_table].n_id, - # src_batch + batch_size * dst_index, - # ).float() - # print (target.shape) - # quit() - # topk_labels = map_index(edge_label_index, topk_idx) - """ - loss += sparse_cross_entropy(tr_logits, edge_label_index) + num_rhs_nodes = gnn_logits.shape[1] + + #* tr_logits: [batch_size, topk], we need to get the edges that exist in the topk prediction + batch_size = topk_idx.shape[0] + topk = topk_idx.shape[1] + idx_position = (torch.arange(batch_size) * num_rhs_nodes).view(-1,1).to(tr_logits.device) + topk_idx = topk_idx + idx_position + + correct_label = torch.isin(topk_idx,src_batch * num_rhs_nodes + dst_index).float() + loss += F.binary_cross_entropy_with_logits(tr_logits, correct_label) + # true_label_index, mask = map_index(topk_idx, edge_label_index) + # correct_label = torch.zeros(tr_logits.shape).to(tr_logits.device) + # correct_label[mask] = True + # loss += sparse_cross_entropy(tr_logits, edge_label_index) numel = len(batch[task.dst_entity_table].batch) - loss.backward() optimizer.step() @@ -282,11 +283,14 @@ def test(model: torch.nn.Module, loader: NeighborLoader, stage: str) -> float: task.dst_entity_table).detach() scores = torch.sigmoid(out) elif args.model in ["rerank_transformer"]: - gnn_logits, tr_logits, topk_index = model(batch, task.src_entity_table, + gnn_logits, tr_logits, topk_idx = model(batch, task.src_entity_table, task.dst_entity_table, task.dst_entity_col) - scores = torch.sigmoid(tr_logits.detach()) - + #! need to change the shape of tr_logits + scores = torch.zeros(batch_size, task.num_dst_nodes, + device=tr_logits.device) + scores.scatter_(1, topk_idx, torch.sigmoid(tr_logits.detach())) + # scores[topk_index] = torch.sigmoid(tr_logits.detach().flatten()) else: raise ValueError(f"Unsupported model type: {args.model}.") diff --git a/hybridgnn/nn/models/rerank_transformer.py b/hybridgnn/nn/models/rerank_transformer.py index e1fb073..28ef133 100644 --- a/hybridgnn/nn/models/rerank_transformer.py +++ b/hybridgnn/nn/models/rerank_transformer.py @@ -1,9 +1,10 @@ -from typing import Any, Dict +from typing import Any, Dict, Optional, Type import torch from torch import Tensor from torch_frame.data.stats import StatType from torch_geometric.data import HeteroData +from torch_frame.nn.models import ResNet from torch_geometric.nn import MLP from torch_geometric.typing import NodeType @@ -47,6 +48,8 @@ def __init__( dropout: float = 0.2, heads: int = 1, rank_topk: int = 100, + torch_frame_model_cls: Type[torch.nn.Module] = ResNet, + torch_frame_model_kwargs: Optional[Dict[str, Any]] = None, ) -> None: super().__init__() @@ -58,6 +61,8 @@ def __init__( }, node_to_col_stats=col_stats_dict, stype_encoder_cls_kwargs=DEFAULT_STYPE_ENCODER_DICT, + torch_frame_model_cls=torch_frame_model_cls, + torch_frame_model_kwargs=torch_frame_model_kwargs, ) self.temporal_encoder = HeteroTemporalEncoder( node_types=[ @@ -88,13 +93,13 @@ def __init__( self.rank_topk = rank_topk self.tr_blocks = torch.nn.ModuleList([ MultiheadAttentionBlock( - channels=embedding_dim, + channels=embedding_dim*2, heads=heads, layer_norm=True, dropout=dropout, ) for _ in range(1) ]) - # self.tr_lin = torch.nn.Linear(embedding_dim*2, embedding_dim) + self.tr_lin = torch.nn.Linear(embedding_dim*2, 1) self.channels = channels @@ -112,7 +117,7 @@ def 
reset_parameters(self) -> None: self.lhs_projector.reset_parameters() for block in self.tr_blocks: block.reset_parameters() - # self.tr_lin.reset_parameters() + self.tr_lin.reset_parameters() def forward( self, @@ -159,6 +164,7 @@ def forward( # Calculate idgnn logits idgnn_logits = self.head( rhs_gnn_embedding).flatten() # num_sampled_rhs + # Because we are only doing 2 hop, we are not really sampling info from # lhs therefore, we need to incorporate this information using # lhs_embedding[lhs_idgnn_batch] * rhs_gnn_embedding @@ -174,22 +180,60 @@ def forward( embgnn_logits[lhs_idgnn_batch, rhs_idgnn_index] = idgnn_logits - #! let's do end to end transformer here - all_rhs_embed = rhs_embedding.weight #only shallow rhs embeds - assert all_rhs_embed.shape[1] == rhs_gnn_embedding.shape[1], "id GNN embed size should be the same as shallow RHS embed size" + shallow_rhs_embed = rhs_embedding.weight + transformer_logits, topk_index = self.rerank(embgnn_logits, shallow_rhs_embed, rhs_gnn_embedding, rhs_idgnn_index, idgnn_logits, lhs_idgnn_batch,lhs_embedding_projected[lhs_idgnn_batch]) + return embgnn_logits, transformer_logits, topk_index - #* rhs_gnn_embedding is significantly smaller than rhs_embed and we can't use inplace operation during backprop - #* -----> this is not global, can't replace like this - copy_tensor = torch.zeros(all_rhs_embed.shape).to(all_rhs_embed.device) - copy_tensor[rhs_idgnn_index] = rhs_gnn_embedding - final_rhs_embed = all_rhs_embed + copy_tensor - # all_rhs_embed[rhs_idgnn_index] = rhs_gnn_embedding # apply the idGNN embeddings here - # transformer_logits, topk_index = self.rerank(embgnn_logits.detach().clone(), final_rhs_embed, lhs_idgnn_batch.detach().clone(), lhs_embedding[lhs_idgnn_batch].detach().clone()) - transformer_logits, topk_index = self.rerank(embgnn_logits, final_rhs_embed, lhs_idgnn_batch, lhs_embedding_projected[lhs_idgnn_batch]) + def rerank(self, gnn_logits, shallow_rhs_embed, rhs_idgnn_embed, rhs_idgnn_index, idgnn_logits, lhs_idgnn_batch, lhs_embedding): + """ + reranks the gnn logits based on the provided gnn embeddings. + shallow_rhs_embed:[# rhs nodes, embed_dim] + + """ + embed_size = rhs_idgnn_embed.shape[1] + batch_size = gnn_logits.shape[0] + num_rhs_nodes = shallow_rhs_embed.shape[0] + + filtered_logits, topk_indices = torch.topk(gnn_logits, self.rank_topk, dim=1) + # [batch_size, topk, embed_size] + seq = shallow_rhs_embed[topk_indices.flatten()].view(batch_size * self.rank_topk, embed_size) + rhs_idgnn_index = lhs_idgnn_batch * num_rhs_nodes + rhs_idgnn_index + + query_rhs_idgnn_index, mask = map_index(topk_indices.view(-1), rhs_idgnn_index) + id_gnn_seq = torch.zeros(batch_size * self.rank_topk, embed_size) + id_gnn_seq[mask] = rhs_idgnn_embed[query_rhs_idgnn_index] + + logit_mask = torch.zeros(batch_size * self.rank_topk, embed_size, dtype=bool) + logit_mask[mask] = True + seq = torch.where(logit_mask, id_gnn_seq.view(-1,embed_size), seq.view(-1,embed_size)) + + unique_lhs_idx = torch.unique(lhs_idgnn_batch) + lhs_uniq_embed = lhs_embedding[unique_lhs_idx] + + seq = seq.clone() + seq = seq.view(batch_size,self.rank_topk,-1) + + lhs_uniq_embed = lhs_uniq_embed.view(-1,1,embed_size) + lhs_uniq_embed = lhs_uniq_embed.expand(-1,seq.shape[1],-1) + seq = torch.cat((seq,lhs_uniq_embed), dim=-1) + + for block in self.tr_blocks: + seq = block(seq, seq) # [# nodes, topk, embed_size] + + #! 
just get the logit directly from transformer + seq = seq.view(-1,embed_size*2) + seq = self.tr_lin(seq) + topk_logits = seq.view(batch_size,self.rank_topk) + + _, topk_indices = torch.topk(gnn_logits, self.rank_topk, dim=1) + return topk_logits, topk_indices + + + + - return embgnn_logits, transformer_logits, topk_index #* adding lhs embedding code not working yet # def rerank(self, gnn_logits, rhs_gnn_embedding, index, lhs_embedding): @@ -219,25 +263,3 @@ def forward( # for idx in range(topk_index.shape[0]): # out_logits[idx][topk_index[idx]] = (lhs_embedding[idx] * tr_embed[idx]).sum(dim=-1).flatten() # return out_logits, topk_index - - - def rerank(self, gnn_logits, rhs_gnn_embedding, index, lhs_embedding): - """ - reranks the gnn logits based on the provided gnn embeddings. - rhs_gnn_embedding:[# rhs nodes, embed_dim] - """ - topk = self.rank_topk - _, topk_index = torch.topk(gnn_logits, self.rank_topk, dim=1) - embed_size = rhs_gnn_embedding.shape[1] - - # need input batch of size [# nodes, topk, embed_size] - top_embed = torch.stack([rhs_gnn_embedding[topk_index[idx]] for idx in range(topk_index.shape[0])]) - for block in self.tr_blocks: - tr_embed = block(top_embed, top_embed) # [# nodes, topk, embed_size] - - #! for top k prediction - out_logits = torch.full(gnn_logits.shape, -float('inf')).to(gnn_logits.device) - # tr_logits = torch.stack([(lhs_embedding[idx] * tr_embed[idx]).sum(dim=-1).flatten() for idx in range(topk_index.shape[0])]) - for idx in range(topk_index.shape[0]): - out_logits[idx][topk_index[idx]] = (lhs_embedding[idx] * tr_embed[idx]).sum(dim=-1).flatten() - return out_logits, topk_index \ No newline at end of file From 7622b94f3a8d130cfc2cd2e3948cd8a39966a746 Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Fri, 23 Aug 2024 21:22:08 +0000 Subject: [PATCH 17/22] push current version --- benchmark/relbench_link_prediction_benchmark.py | 12 ++++++++++-- hybridgnn/nn/models/rerank_transformer.py | 14 +++++++++----- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index 30042eb..e1545c8 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -230,13 +230,18 @@ def train( idx_position = (torch.arange(batch_size) * num_rhs_nodes).view(-1,1).to(tr_logits.device) topk_idx = topk_idx + idx_position + """ + debug if this is correct + """ correct_label = torch.isin(topk_idx,src_batch * num_rhs_nodes + dst_index).float() loss += F.binary_cross_entropy_with_logits(tr_logits, correct_label) + numel = len(batch[task.dst_entity_table].batch) + + # true_label_index, mask = map_index(topk_idx, edge_label_index) # correct_label = torch.zeros(tr_logits.shape).to(tr_logits.device) # correct_label[mask] = True # loss += sparse_cross_entropy(tr_logits, edge_label_index) - numel = len(batch[task.dst_entity_table].batch) loss.backward() optimizer.step() @@ -289,7 +294,10 @@ def test(model: torch.nn.Module, loader: NeighborLoader, stage: str) -> float: #! 
need to change the shape of tr_logits
             scores = torch.zeros(batch_size, task.num_dst_nodes,
                                  device=tr_logits.device)
-            scores.scatter_(1, topk_idx, torch.sigmoid(tr_logits.detach()))
+            tr_logits = tr_logits.detach()
+            for i in range(scores.shape[0]):
+                scores[i][topk_idx[i]] = torch.sigmoid(tr_logits[i])
+            # scores.scatter_(1, topk_idx, torch.sigmoid(tr_logits.detach()))
             # scores[topk_index] = torch.sigmoid(tr_logits.detach().flatten())
 
diff --git a/hybridgnn/nn/models/rerank_transformer.py b/hybridgnn/nn/models/rerank_transformer.py
index 28ef133..d7904b0 100644
--- a/hybridgnn/nn/models/rerank_transformer.py
+++ b/hybridgnn/nn/models/rerank_transformer.py
@@ -99,7 +99,7 @@
             dropout=dropout,
         ) for _ in range(1)
     ])
-        self.tr_lin = torch.nn.Linear(embedding_dim*2, 1)
+        self.tr_lin = torch.nn.Linear(embedding_dim*2, embedding_dim)
 
         self.channels = channels
 
@@ -196,6 +196,7 @@
         num_rhs_nodes = shallow_rhs_embed.shape[0]
 
         filtered_logits, topk_indices = torch.topk(gnn_logits, self.rank_topk, dim=1)
+        out_indices = topk_indices.clone()
         # [batch_size, topk, embed_size]
         seq = shallow_rhs_embed[topk_indices.flatten()].view(batch_size * self.rank_topk, embed_size)
         rhs_idgnn_index = lhs_idgnn_batch * num_rhs_nodes + rhs_idgnn_index
@@ -223,12 +224,15 @@
 
         #! just get the logit directly from transformer
         seq = seq.view(-1,embed_size*2)
-        seq = self.tr_lin(seq)
-        topk_logits = seq.view(batch_size,self.rank_topk)
+        seq = self.tr_lin(seq) # [batch_size, embed_size]
+        seq = seq.view(batch_size * self.rank_topk, embed_size)
+        lhs_uniq_embed = lhs_uniq_embed.reshape(batch_size * self.rank_topk, embed_size)
 
-        _, topk_indices = torch.topk(gnn_logits, self.rank_topk, dim=1)
-        return topk_logits, topk_indices
+        tr_logits = (lhs_uniq_embed.view(-1, embed_size) * seq.view(-1, embed_size)).sum(
+            dim=-1).flatten()
+        tr_logits = tr_logits.view(batch_size,self.rank_topk)
 
+        return tr_logits, out_indices
 
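The patch below swaps the hand-rolled `torch.isin` labelling for `map_index`. A minimal sketch of that labelling step on hypothetical toy tensors (`map_index` is the `torch_geometric.utils.map` helper already imported in the benchmark; the values are made up for illustration):

```python
import torch
from torch_geometric.utils.map import map_index

# Hypothetical batch: 3 seed nodes, top-4 candidates each, 12 RHS nodes.
topk_idx = torch.tensor([[0, 5, 7, 2],
                         [1, 5, 9, 3],
                         [4, 0, 8, 6]])
dst_index = torch.tensor([5, 3, 8])  # ground-truth destination ids in the batch

# With inclusive=False (the default), map_index also returns a boolean mask
# flagging which flattened top-k entries occur anywhere in dst_index.
_, mask = map_index(topk_idx.view(-1), dst_index)

true_label = torch.zeros(topk_idx.shape)
true_label[mask.view(true_label.shape)] = 1.0
# Caveat (visible in the patch as well): a candidate is marked positive when
# its id is a positive for any seed in the batch, not only for its own row.
```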
From dc73f7c7f7739c01f4d6b9ce3ad183372f052749 Mon Sep 17 00:00:00 2001
From: shenyang huang
Date: Fri, 23 Aug 2024 22:38:45 +0000
Subject: [PATCH 18/22] converging to the following implementation

---
 .../relbench_link_prediction_benchmark.py     | 60 ++++++++++---------
 1 file changed, 33 insertions(+), 27 deletions(-)

diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py
index e1545c8..3115de5 100644
--- a/benchmark/relbench_link_prediction_benchmark.py
+++ b/benchmark/relbench_link_prediction_benchmark.py
@@ -164,7 +164,7 @@
         "embedding_dim": [64],
         "norm": ["layer_norm", "batch_norm"],
         "dropout": [0.1, 0.2],
-        "rank_topk": [100]
+        "rank_topk": [500],
     }
     train_search_space = {
         "batch_size": [128, 256, 512],
@@ -178,6 +178,7 @@ def train(
     optimizer: torch.optim.Optimizer,
     loader: NeighborLoader,
     train_sparse_tensor: SparseTensor,
+    epoch:int,
 ) -> float:
     model.train()
 
@@ -222,26 +223,23 @@
         gnn_logits, tr_logits, topk_idx = model(batch, task.src_entity_table, task.dst_entity_table, task.dst_entity_col)
         edge_label_index = torch.stack([src_batch, dst_index], dim=0)
         loss = sparse_cross_entropy(gnn_logits, edge_label_index)
-        num_rhs_nodes = gnn_logits.shape[1]
-
-        #* tr_logits: [batch_size, topk], we need to get the edges that exist in the topk prediction
-        batch_size = topk_idx.shape[0]
-        topk = topk_idx.shape[1]
-        idx_position = (torch.arange(batch_size) * num_rhs_nodes).view(-1,1).to(tr_logits.device)
-        topk_idx = topk_idx + idx_position
-
-        """
-        debug if this is correct
-        """
-        correct_label = torch.isin(topk_idx,src_batch * num_rhs_nodes + dst_index).float()
-        loss += F.binary_cross_entropy_with_logits(tr_logits, correct_label)
         numel = len(batch[task.dst_entity_table].batch)
-
-
-        # true_label_index, mask = map_index(topk_idx, edge_label_index)
-        # correct_label = torch.zeros(tr_logits.shape).to(tr_logits.device)
-        # correct_label[mask] = True
-        # loss += sparse_cross_entropy(tr_logits, edge_label_index)
+        if (epoch > 0):
+            # num_rhs_nodes = gnn_logits.shape[1]
+            # #* tr_logits: [batch_size, topk], we need to get the edges that exist in the topk prediction, likely incorrect
+            # batch_size = topk_idx.shape[0]
+            # topk = topk_idx.shape[1]
+            # idx_position = (torch.arange(batch_size) * num_rhs_nodes).view(-1,1).to(tr_logits.device)
+            # topk_idx = topk_idx + idx_position
+            # correct_label = torch.isin(topk_idx,src_batch * num_rhs_nodes + dst_index).float()
+
+            #* approach with map_index
+            label_index, mask = map_index(topk_idx.view(-1), dst_index)
+            true_label = torch.zeros(topk_idx.shape).to(tr_logits.device)
+            true_label[mask.view(true_label.shape)] = 1.0
+            loss += F.binary_cross_entropy_with_logits(tr_logits, 
need to change the shape of tr_logits - scores = torch.zeros(batch_size, task.num_dst_nodes, - device=tr_logits.device) - tr_logits = tr_logits.detach() - for i in range(scores.shape[0]): - scores[i][topk_idx[i]] = torch.sigmoid(tr_logits[i]) + + _, pred_idx = torch.topk(tr_logits.detach(), k=task.eval_k, dim=1) + pred_mini = topk_idx[torch.arange(topk_idx.size(0)).unsqueeze(1), pred_idx] + + #! to remove + # scores = torch.zeros(batch_size, task.num_dst_nodes, + # device=tr_logits.device) + # tr_logits = tr_logits.detach() + # scores.scatter_(1, topk_idx, torch.sigmoid(tr_logits.detach())) + + # for i in range(scores.shape[0]): + # scores[i][topk_idx[i]] = torch.sigmoid(tr_logits[i]) # scores.scatter_(1, topk_idx, torch.sigmoid(tr_logits.detach())) # scores[topk_index] = torch.sigmoid(tr_logits.detach().flatten()) else: raise ValueError(f"Unsupported model type: {args.model}.") - _, pred_mini = torch.topk(scores, k=task.eval_k, dim=1) pred_list.append(pred_mini) pred = torch.cat(pred_list, dim=0).cpu().numpy() @@ -371,7 +377,7 @@ def train_and_eval_with_cfg( train_sparse_tensor = SparseTensor(dst_nodes_dict["train"][1], device=device) train_loss = train(model, optimizer, loader_dict["train"], - train_sparse_tensor) + train_sparse_tensor, epoch) optimizer.zero_grad() val_metric = test(model, loader_dict["val"], "val") From 38f9cf481375ac7a8ed20ec46cc3bd2e315efeef Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Fri, 23 Aug 2024 22:41:02 +0000 Subject: [PATCH 19/22] training both from epoch 0, potentially try better ways --- .../relbench_link_prediction_benchmark.py | 27 +++++++++---------- hybridgnn/nn/models/rerank_transformer.py | 1 - 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index 3115de5..13ddae0 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -225,20 +225,19 @@ def train( loss = sparse_cross_entropy(gnn_logits, edge_label_index) numel = len(batch[task.dst_entity_table].batch) - if (epoch > 0): - # num_rhs_nodes = gnn_logits.shape[1] - # #* tr_logits: [batch_size, topk], we need to get the edges that exist in the topk prediction, likely incorrect - # batch_size = topk_idx.shape[0] - # topk = topk_idx.shape[1] - # idx_position = (torch.arange(batch_size) * num_rhs_nodes).view(-1,1).to(tr_logits.device) - # topk_idx = topk_idx + idx_position - # correct_label = torch.isin(topk_idx,src_batch * num_rhs_nodes + dst_index).float() - - #* approach with map_index - label_index, mask = map_index(topk_idx.view(-1), dst_index) - true_label = torch.zeros(topk_idx.shape).to(tr_logits.device) - true_label[mask.view(true_label.shape)] = 1.0 - loss += F.binary_cross_entropy_with_logits(tr_logits, true_label.float()) + # num_rhs_nodes = gnn_logits.shape[1] + # #* tr_logits: [batch_size, topk], we need to get the edges that exist in the topk prediction, likely incorrect + # batch_size = topk_idx.shape[0] + # topk = topk_idx.shape[1] + # idx_position = (torch.arange(batch_size) * num_rhs_nodes).view(-1,1).to(tr_logits.device) + # topk_idx = topk_idx + idx_position + # correct_label = torch.isin(topk_idx,src_batch * num_rhs_nodes + dst_index).float() + + #* approach with map_index + label_index, mask = map_index(topk_idx.view(-1), dst_index) + true_label = torch.zeros(topk_idx.shape).to(tr_logits.device) + true_label[mask.view(true_label.shape)] = 1.0 + loss += F.binary_cross_entropy_with_logits(tr_logits, 
true_label.float()) loss.backward() diff --git a/hybridgnn/nn/models/rerank_transformer.py b/hybridgnn/nn/models/rerank_transformer.py index d7904b0..874e209 100644 --- a/hybridgnn/nn/models/rerank_transformer.py +++ b/hybridgnn/nn/models/rerank_transformer.py @@ -189,7 +189,6 @@ def rerank(self, gnn_logits, shallow_rhs_embed, rhs_idgnn_embed, rhs_idgnn_index """ reranks the gnn logits based on the provided gnn embeddings. shallow_rhs_embed:[# rhs nodes, embed_dim] - """ embed_size = rhs_idgnn_embed.shape[1] batch_size = gnn_logits.shape[0] From 8e846dd2bf9dfcf6202b4f084546b006b0af54da Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Sun, 25 Aug 2024 18:28:03 +0000 Subject: [PATCH 20/22] for transformer, not training with nodes whose prediction is not within topk --- benchmark/relbench_link_prediction_benchmark.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index 13ddae0..c7e3a6b 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -164,7 +164,7 @@ "embedding_dim": [64], "norm": ["layer_norm", "batch_norm"], "dropout": [0.1, 0.2], - "rank_topk": [500], + "rank_topk": [200], } train_search_space = { "batch_size": [128, 256, 512], @@ -237,6 +237,13 @@ def train( label_index, mask = map_index(topk_idx.view(-1), dst_index) true_label = torch.zeros(topk_idx.shape).to(tr_logits.device) true_label[mask.view(true_label.shape)] = 1.0 + + #* empty label rows + nonzero_mask = torch.any(true_label, dim=1) + tr_logits = tr_logits[nonzero_mask] + true_label = true_label[nonzero_mask] + + #* the loss of the transformer should be scaled down? ((topk_idx.shape[1] / gnn_logits.shape[1])) loss += F.binary_cross_entropy_with_logits(tr_logits, true_label.float()) loss.backward() From 41a38cf10f1ca77a783419512af91bf14c67b21c Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Mon, 26 Aug 2024 18:01:05 +0000 Subject: [PATCH 21/22] commiting before clean up --- benchmark/relbench_link_prediction_benchmark.py | 14 +++++++------- hybridgnn/nn/models/rerank_transformer.py | 3 +-- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index c7e3a6b..663bbfd 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -164,7 +164,7 @@ "embedding_dim": [64], "norm": ["layer_norm", "batch_norm"], "dropout": [0.1, 0.2], - "rank_topk": [200], + "rank_topk": [100], } train_search_space = { "batch_size": [128, 256, 512], @@ -232,20 +232,20 @@ def train( # idx_position = (torch.arange(batch_size) * num_rhs_nodes).view(-1,1).to(tr_logits.device) # topk_idx = topk_idx + idx_position # correct_label = torch.isin(topk_idx,src_batch * num_rhs_nodes + dst_index).float() - + #* approach with map_index label_index, mask = map_index(topk_idx.view(-1), dst_index) true_label = torch.zeros(topk_idx.shape).to(tr_logits.device) true_label[mask.view(true_label.shape)] = 1.0 - #* empty label rows - nonzero_mask = torch.any(true_label, dim=1) - tr_logits = tr_logits[nonzero_mask] - true_label = true_label[nonzero_mask] + # #* empty label rows + # nonzero_mask = torch.any(true_label, dim=1) + # tr_logits = tr_logits[nonzero_mask] + # true_label = true_label[nonzero_mask] #* the loss of the transformer should be scaled down? 
((topk_idx.shape[1] / gnn_logits.shape[1]))
         loss += F.binary_cross_entropy_with_logits(tr_logits, true_label.float())
 
         loss.backward()
 
From 41a38cf10f1ca77a783419512af91bf14c67b21c Mon Sep 17 00:00:00 2001
From: shenyang huang
Date: Mon, 26 Aug 2024 18:01:05 +0000
Subject: [PATCH 21/22] committing before clean up

---
 benchmark/relbench_link_prediction_benchmark.py | 14 +++++++-------
 hybridgnn/nn/models/rerank_transformer.py       |  3 +--
 2 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py
index c7e3a6b..663bbfd 100644
--- a/benchmark/relbench_link_prediction_benchmark.py
+++ b/benchmark/relbench_link_prediction_benchmark.py
@@ -164,7 +164,7 @@
         "embedding_dim": [64],
         "norm": ["layer_norm", "batch_norm"],
         "dropout": [0.1, 0.2],
-        "rank_topk": [200],
+        "rank_topk": [100],
     }
     train_search_space = {
         "batch_size": [128, 256, 512],
@@ -232,20 +232,20 @@ def train(
 # idx_position = (torch.arange(batch_size) * num_rhs_nodes).view(-1,1).to(tr_logits.device)
 # topk_idx = topk_idx + idx_position
 # correct_label = torch.isin(topk_idx,src_batch * num_rhs_nodes + dst_index).float()
-
+        
         #* approach with map_index
         label_index, mask = map_index(topk_idx.view(-1), dst_index)
         true_label = torch.zeros(topk_idx.shape).to(tr_logits.device)
         true_label[mask.view(true_label.shape)] = 1.0
 
-        #* empty label rows
-        nonzero_mask = torch.any(true_label, dim=1)
-        tr_logits = tr_logits[nonzero_mask]
-        true_label = true_label[nonzero_mask]
+        # #* empty label rows
+        # nonzero_mask = torch.any(true_label, dim=1)
+        # tr_logits = tr_logits[nonzero_mask]
+        # true_label = true_label[nonzero_mask]
 
         #* the loss of the transformer should be scaled down? ((topk_idx.shape[1] / gnn_logits.shape[1]))
         loss += F.binary_cross_entropy_with_logits(tr_logits, true_label.float())
-
+        
         loss.backward()
         optimizer.step()
diff --git a/hybridgnn/nn/models/rerank_transformer.py b/hybridgnn/nn/models/rerank_transformer.py
index 874e209..1915a10 100644
--- a/hybridgnn/nn/models/rerank_transformer.py
+++ b/hybridgnn/nn/models/rerank_transformer.py
@@ -208,8 +208,7 @@ def rerank(self, gnn_logits, shallow_rhs_embed, rhs_idgnn_embed, rhs_idgnn_index
         logit_mask[mask] = True
         seq = torch.where(logit_mask, id_gnn_seq.view(-1,embed_size), seq.view(-1,embed_size))
 
-        unique_lhs_idx = torch.unique(lhs_idgnn_batch)
-        lhs_uniq_embed = lhs_embedding[unique_lhs_idx]
+        lhs_uniq_embed = lhs_embedding[:batch_size]
 
         seq = seq.clone()
From 3e71188a964b89c7372902e792b9e99b76f98c66 Mon Sep 17 00:00:00 2001
From: shenyang huang
Date: Mon, 26 Aug 2024 18:06:26 +0000
Subject: [PATCH 22/22] clean up, removing reference to rerank transformer

---
 .../relbench_link_prediction_benchmark.py     | 70 +----
 hybridgnn/nn/models/__init__.py               |  3 +-
 hybridgnn/nn/models/rerank_transformer.py     | 267 ------------------
 3 files changed, 5 insertions(+), 335 deletions(-)
 delete mode 100644 hybridgnn/nn/models/rerank_transformer.py

diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py
index 663bbfd..e9cdd39 100644
--- a/benchmark/relbench_link_prediction_benchmark.py
+++ b/benchmark/relbench_link_prediction_benchmark.py
@@ -34,7 +34,7 @@
 from torch_geometric.utils.cross_entropy import sparse_cross_entropy
 from tqdm import tqdm
 
-from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN, RHSTransformer, ReRankTransformer
+from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN, RHSTransformer
 from hybridgnn.utils import GloveTextEmbedding
 from torch_geometric.utils import index_to_mask
 from torch_geometric.utils.map import map_index
@@ -51,7 +51,7 @@
     "--model",
     type=str,
     default="hybridgnn",
-    choices=["hybridgnn", "idgnn", "shallowrhsgnn", "rhstransformer", "rerank_transformer"],
+    choices=["hybridgnn", "idgnn", "shallowrhsgnn", "rhstransformer"],
 )
 parser.add_argument("--epochs", type=int, default=20)
 parser.add_argument("--num_trials", type=int, default=10,
@@ -111,7 +111,7 @@
     int(args.num_neighbors // 2**i) for i in range(args.num_layers)
 ]
 
-model_cls: Type[Union[IDGNN, HybridGNN, ShallowRHSGNN, RHSTransformer, ReRankTransformer]]
+model_cls: Type[Union[IDGNN, HybridGNN, ShallowRHSGNN, RHSTransformer]]
 
 if args.model == "idgnn":
     model_search_space = {
@@ -156,22 +156,6 @@
         "gamma_rate": [0.9, 1.0],
     }
     model_cls = RHSTransformer
-elif args.model in ["rerank_transformer"]:
-    model_search_space = {
-        "encoder_channels": [64, 128, 256],
-        "encoder_layers": [2, 4, 8],
-        "channels": [64],
-        "embedding_dim": [64],
-        "norm": ["layer_norm", "batch_norm"],
-        "dropout": [0.1, 0.2],
-        "rank_topk": [100],
-    }
-    train_search_space = {
-        "batch_size": [128, 256, 512],
-        "base_lr": [0.0005, 0.01],
-        "gamma_rate": [0.9, 1.0],
-    }
-    model_cls = ReRankTransformer
 
 def train(
     model: torch.nn.Module,
@@ -219,33 +203,6 @@ def train(
         edge_label_index = torch.stack([src_batch, dst_index], dim=0)
         loss = sparse_cross_entropy(logits, edge_label_index)
         numel = len(batch[task.dst_entity_table].batch)
-    elif args.model in ["rerank_transformer"]:
-        gnn_logits, tr_logits, topk_idx = model(batch, task.src_entity_table, task.dst_entity_table, task.dst_entity_col)
-        edge_label_index = torch.stack([src_batch, dst_index], dim=0)
-        
loss = sparse_cross_entropy(gnn_logits, edge_label_index) - numel = len(batch[task.dst_entity_table].batch) - - # num_rhs_nodes = gnn_logits.shape[1] - # #* tr_logits: [batch_size, topk], we need to get the edges that exist in the topk prediction, likely incorrect - # batch_size = topk_idx.shape[0] - # topk = topk_idx.shape[1] - # idx_position = (torch.arange(batch_size) * num_rhs_nodes).view(-1,1).to(tr_logits.device) - # topk_idx = topk_idx + idx_position - # correct_label = torch.isin(topk_idx,src_batch * num_rhs_nodes + dst_index).float() - - #* approach with map_index - label_index, mask = map_index(topk_idx.view(-1), dst_index) - true_label = torch.zeros(topk_idx.shape).to(tr_logits.device) - true_label[mask.view(true_label.shape)] = 1.0 - - # #* empty label rows - # nonzero_mask = torch.any(true_label, dim=1) - # tr_logits = tr_logits[nonzero_mask] - # true_label = true_label[nonzero_mask] - - #* the loss of the transformer should be scaled down? ((topk_idx.shape[1] / gnn_logits.shape[1])) - loss += F.binary_cross_entropy_with_logits(tr_logits, true_label.float()) - loss.backward() optimizer.step() @@ -294,25 +251,6 @@ def test(model: torch.nn.Module, loader: NeighborLoader, stage: str) -> float: task.dst_entity_table).detach() scores = torch.sigmoid(out) _, pred_mini = torch.topk(scores, k=task.eval_k, dim=1) - elif args.model in ["rerank_transformer"]: - gnn_logits, tr_logits, topk_idx = model(batch, task.src_entity_table, - task.dst_entity_table, - task.dst_entity_col) - - _, pred_idx = torch.topk(tr_logits.detach(), k=task.eval_k, dim=1) - pred_mini = topk_idx[torch.arange(topk_idx.size(0)).unsqueeze(1), pred_idx] - - #! to remove - # scores = torch.zeros(batch_size, task.num_dst_nodes, - # device=tr_logits.device) - # tr_logits = tr_logits.detach() - # scores.scatter_(1, topk_idx, torch.sigmoid(tr_logits.detach())) - - # for i in range(scores.shape[0]): - # scores[i][topk_idx[i]] = torch.sigmoid(tr_logits[i]) - # scores.scatter_(1, topk_idx, torch.sigmoid(tr_logits.detach())) - # scores[topk_index] = torch.sigmoid(tr_logits.detach().flatten()) - else: raise ValueError(f"Unsupported model type: {args.model}.") @@ -350,7 +288,7 @@ def train_and_eval_with_cfg( persistent_workers=args.num_workers > 0, ) - if args.model in ["hybridgnn", "shallowrhsgnn", "rhstransformer", "rerank_transformer"]: + if args.model in ["hybridgnn", "shallowrhsgnn", "rhstransformer"]: model_cfg["num_nodes"] = num_dst_nodes_dict["train"] elif args.model == "idgnn": model_cfg["out_channels"] = 1 diff --git a/hybridgnn/nn/models/__init__.py b/hybridgnn/nn/models/__init__.py index 7b6c7e8..dadbc32 100644 --- a/hybridgnn/nn/models/__init__.py +++ b/hybridgnn/nn/models/__init__.py @@ -3,9 +3,8 @@ from .hybridgnn import HybridGNN from .shallowrhsgnn import ShallowRHSGNN from .RHSTransformer import RHSTransformer -from .rerank_transformer import ReRankTransformer __all__ = classes = [ 'HeteroGraphSAGE', 'IDGNN', 'HybridGNN', 'ShallowRHSGNN', - 'RHSTransformer', 'ReRankTransformer' + 'RHSTransformer' ] diff --git a/hybridgnn/nn/models/rerank_transformer.py b/hybridgnn/nn/models/rerank_transformer.py deleted file mode 100644 index 1915a10..0000000 --- a/hybridgnn/nn/models/rerank_transformer.py +++ /dev/null @@ -1,267 +0,0 @@ -from typing import Any, Dict, Optional, Type - -import torch -from torch import Tensor -from torch_frame.data.stats import StatType -from torch_geometric.data import HeteroData -from torch_frame.nn.models import ResNet -from torch_geometric.nn import MLP -from torch_geometric.typing import 
NodeType - -from hybridgnn.nn.encoder import ( - DEFAULT_STYPE_ENCODER_DICT, - HeteroEncoder, - HeteroTemporalEncoder, -) -from hybridgnn.nn.models import HeteroGraphSAGE -from torch_scatter import scatter_max -from torch_geometric.nn.aggr.utils import MultiheadAttentionBlock -from torch_geometric.utils import to_dense_batch -from torch_geometric.utils.map import map_index - - - -class ReRankTransformer(torch.nn.Module): - r"""Implementation of ReRank Transformer model. - Args: - data (HeteroData): dataset - col_stats_dict (Dict[str, Dict[str, Dict[StatType, Any]]]): column stats - num_nodes (int): number of nodes, - num_layers (int): number of mp layers, - channels (int): input dimension, - embedding_dim (int): embedding dimension size, - aggr (str): aggregation type, - norm (norm): normalization type, - dropout (float): dropout rate for the transformer float, - heads (int): number of attention heads, - rank_topk (int): how many top results of gnn would be reranked,""" - def __init__( - self, - data: HeteroData, - col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]], - num_nodes: int, - num_layers: int, - channels: int, - embedding_dim: int, - aggr: str = 'sum', - norm: str = 'layer_norm', - dropout: float = 0.2, - heads: int = 1, - rank_topk: int = 100, - torch_frame_model_cls: Type[torch.nn.Module] = ResNet, - torch_frame_model_kwargs: Optional[Dict[str, Any]] = None, - ) -> None: - super().__init__() - - self.encoder = HeteroEncoder( - channels=channels, - node_to_col_names_dict={ - node_type: data[node_type].tf.col_names_dict - for node_type in data.node_types - }, - node_to_col_stats=col_stats_dict, - stype_encoder_cls_kwargs=DEFAULT_STYPE_ENCODER_DICT, - torch_frame_model_cls=torch_frame_model_cls, - torch_frame_model_kwargs=torch_frame_model_kwargs, - ) - self.temporal_encoder = HeteroTemporalEncoder( - node_types=[ - node_type for node_type in data.node_types - if "time" in data[node_type] - ], - channels=channels, - ) - self.gnn = HeteroGraphSAGE( - node_types=data.node_types, - edge_types=data.edge_types, - channels=channels, - aggr=aggr, - num_layers=num_layers, - ) - self.head = MLP( - channels, - out_channels=1, - norm=norm, - num_layers=1, - ) - self.lhs_projector = torch.nn.Linear(channels, embedding_dim) - self.id_awareness_emb = torch.nn.Embedding(1, channels) - self.rhs_embedding = torch.nn.Embedding(num_nodes, embedding_dim) - self.lin_offset_idgnn = torch.nn.Linear(embedding_dim, 1) - self.lin_offset_embgnn = torch.nn.Linear(embedding_dim, 1) - - self.rank_topk = rank_topk - self.tr_blocks = torch.nn.ModuleList([ - MultiheadAttentionBlock( - channels=embedding_dim*2, - heads=heads, - layer_norm=True, - dropout=dropout, - ) for _ in range(1) - ]) - self.tr_lin = torch.nn.Linear(embedding_dim*2, embedding_dim) - - self.channels = channels - - self.reset_parameters() - - def reset_parameters(self) -> None: - self.encoder.reset_parameters() - self.temporal_encoder.reset_parameters() - self.gnn.reset_parameters() - self.head.reset_parameters() - self.id_awareness_emb.reset_parameters() - self.rhs_embedding.reset_parameters() - self.lin_offset_embgnn.reset_parameters() - self.lin_offset_idgnn.reset_parameters() - self.lhs_projector.reset_parameters() - for block in self.tr_blocks: - block.reset_parameters() - self.tr_lin.reset_parameters() - - def forward( - self, - batch: HeteroData, - entity_table: NodeType, - dst_table: NodeType, - dst_entity_col: NodeType, - ) -> Tensor: - - seed_time = batch[entity_table].seed_time - x_dict = self.encoder(batch.tf_dict) - - # Add 
ID-awareness to the root node - x_dict[entity_table][:seed_time.size(0 - )] += self.id_awareness_emb.weight - rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict, - batch.batch_dict) - - for node_type, rel_time in rel_time_dict.items(): - x_dict[node_type] = x_dict[node_type] + rel_time - - x_dict = self.gnn( - x_dict, - batch.edge_index_dict, - ) - - batch_size = seed_time.size(0) - lhs_embedding = x_dict[entity_table][: - batch_size] # batch_size, channel - lhs_embedding_projected = self.lhs_projector(lhs_embedding) - rhs_gnn_embedding = x_dict[dst_table] # num_sampled_rhs, channel - rhs_idgnn_index = batch.n_id_dict[dst_table] # num_sampled_rhs - lhs_idgnn_batch = batch.batch_dict[dst_table] # batch_size - - rhs_embedding = self.rhs_embedding # num_rhs_nodes, channel - embgnn_logits = lhs_embedding_projected @ rhs_embedding.weight.t( - ) # batch_size, num_rhs_nodes - - # Model the importance of embedding-GNN prediction for each lhs node - embgnn_offset_logits = self.lin_offset_embgnn( - lhs_embedding_projected).flatten() - embgnn_logits += embgnn_offset_logits.view(-1, 1) - - # Calculate idgnn logits - idgnn_logits = self.head( - rhs_gnn_embedding).flatten() # num_sampled_rhs - - # Because we are only doing 2 hop, we are not really sampling info from - # lhs therefore, we need to incorporate this information using - # lhs_embedding[lhs_idgnn_batch] * rhs_gnn_embedding - idgnn_logits += ( - lhs_embedding[lhs_idgnn_batch] * # num_sampled_rhs, channel - rhs_gnn_embedding).sum( - dim=-1).flatten() # num_sampled_rhs, channel - - # Model the importance of ID-GNN prediction for each lhs node - idgnn_offset_logits = self.lin_offset_idgnn( - lhs_embedding_projected).flatten() - idgnn_logits = idgnn_logits + idgnn_offset_logits[lhs_idgnn_batch] - - embgnn_logits[lhs_idgnn_batch, rhs_idgnn_index] = idgnn_logits - - shallow_rhs_embed = rhs_embedding.weight - transformer_logits, topk_index = self.rerank(embgnn_logits, shallow_rhs_embed, rhs_gnn_embedding, rhs_idgnn_index, idgnn_logits, lhs_idgnn_batch,lhs_embedding_projected[lhs_idgnn_batch]) - return embgnn_logits, transformer_logits, topk_index - - - def rerank(self, gnn_logits, shallow_rhs_embed, rhs_idgnn_embed, rhs_idgnn_index, idgnn_logits, lhs_idgnn_batch, lhs_embedding): - """ - reranks the gnn logits based on the provided gnn embeddings. 
- shallow_rhs_embed:[# rhs nodes, embed_dim] - """ - embed_size = rhs_idgnn_embed.shape[1] - batch_size = gnn_logits.shape[0] - num_rhs_nodes = shallow_rhs_embed.shape[0] - - filtered_logits, topk_indices = torch.topk(gnn_logits, self.rank_topk, dim=1) - out_indices = topk_indices.clone() - # [batch_size, topk, embed_size] - seq = shallow_rhs_embed[topk_indices.flatten()].view(batch_size * self.rank_topk, embed_size) - rhs_idgnn_index = lhs_idgnn_batch * num_rhs_nodes + rhs_idgnn_index - - query_rhs_idgnn_index, mask = map_index(topk_indices.view(-1), rhs_idgnn_index) - id_gnn_seq = torch.zeros(batch_size * self.rank_topk, embed_size) - id_gnn_seq[mask] = rhs_idgnn_embed[query_rhs_idgnn_index] - - logit_mask = torch.zeros(batch_size * self.rank_topk, embed_size, dtype=bool) - logit_mask[mask] = True - seq = torch.where(logit_mask, id_gnn_seq.view(-1,embed_size), seq.view(-1,embed_size)) - - lhs_uniq_embed = lhs_embedding[:batch_size] - - seq = seq.clone() - seq = seq.view(batch_size,self.rank_topk,-1) - - lhs_uniq_embed = lhs_uniq_embed.view(-1,1,embed_size) - lhs_uniq_embed = lhs_uniq_embed.expand(-1,seq.shape[1],-1) - seq = torch.cat((seq,lhs_uniq_embed), dim=-1) - - for block in self.tr_blocks: - seq = block(seq, seq) # [# nodes, topk, embed_size] - - #! just get the logit directly from transformer - seq = seq.view(-1,embed_size*2) - seq = self.tr_lin(seq) # [batch_size, embed_size] - seq = seq.view(batch_size * self.rank_topk, embed_size) - lhs_uniq_embed = lhs_uniq_embed.reshape(batch_size * self.rank_topk, embed_size) - - tr_logits = (lhs_uniq_embed.view(-1, embed_size) * seq.view(-1, embed_size)).sum( - dim=-1).flatten() - tr_logits = tr_logits.view(batch_size,self.rank_topk) - - return tr_logits, out_indices - - - - - - - #* adding lhs embedding code not working yet - # def rerank(self, gnn_logits, rhs_gnn_embedding, index, lhs_embedding): - # """ - # reranks the gnn logits based on the provided gnn embeddings. - # rhs_gnn_embedding:[# rhs nodes, embed_dim] - # """ - # topk = self.rank_topk - # _, topk_index = torch.topk(gnn_logits, self.rank_topk, dim=1) - # embed_size = rhs_gnn_embedding.shape[1] - - # # need input batch of size [# nodes, topk, embed_size] - # #! concatenate the lhs embedding with rhs embedding - # top_embed = torch.stack([torch.cat((rhs_gnn_embedding[topk_index[idx]],lhs_embedding[idx].view(1,-1).expand(self.rank_topk,-1)), dim=1) for idx in range(topk_index.shape[0])]) - # tr_embed = top_embed - # for block in self.tr_blocks: - # tr_embed = block(tr_embed, tr_embed) # [# nodes, topk, embed_size] - - # tr_embed = tr_embed.view(-1,embed_size*2) - # tr_embed = self.tr_lin(tr_embed) - # tr_embed = tr_embed.view(-1,self.rank_topk,embed_size) - - - # #! for top k prediction - # out_logits = torch.full(gnn_logits.shape, -float('inf')).to(gnn_logits.device) - # # tr_logits = torch.stack([(lhs_embedding[idx] * tr_embed[idx]).sum(dim=-1).flatten() for idx in range(topk_index.shape[0])]) - # for idx in range(topk_index.shape[0]): - # out_logits[idx][topk_index[idx]] = (lhs_embedding[idx] * tr_embed[idx]).sum(dim=-1).flatten() - # return out_logits, topk_index
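With the reranker gone, the surviving model is `RHSTransformer`, whose `Transformer` block (patch 15, `hybridgnn/nn/models/transformer.py`) attends over the sampled RHS embeddings of each seed node. A minimal sketch, on hypothetical toy tensors, of the stable-sort / `to_dense_batch` / inverse-permutation round trip that forward pass relies on (the attention itself is omitted, so the round trip is the identity):

```python
import torch
from torch_geometric.utils import to_dense_batch

# Toy input: 6 sampled RHS nodes assigned to 3 seed nodes.
rhs_embed = torch.randn(6, 8)
index = torch.tensor([2, 0, 1, 0, 2, 1])  # seed id of each RHS node
batch_size = 3

# Stable sort by seed id so to_dense_batch can group rows per seed.
sorted_index, perm = torch.sort(index, stable=True)
dense, mask = to_dense_batch(rhs_embed[perm], sorted_index,
                             batch_size=batch_size)
# dense: [3, max_rhs_per_seed, 8]; each MultiheadAttentionBlock attends within
# a row, i.e. only among RHS nodes that share a seed.

# Flatten the valid entries and undo the sort (cf. inverse_permutation above).
out = dense[mask]
inv = torch.empty_like(perm)
inv[perm] = torch.arange(perm.size(0))
assert torch.equal(out[inv], rhs_embed)
```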