From 829fee50a9f132c327e7f38f890dcaaacfe29e92 Mon Sep 17 00:00:00 2001
From: Shenyang Huang
Date: Thu, 8 Aug 2024 22:06:10 +0000
Subject: [PATCH] Add initial implementation of the RHS transformer

---
 .gitignore                                    |   1 +
 README.md                                     |   3 +
 .../relbench_link_prediction_benchmark.py     |  22 ++-
 hybridgnn/nn/models/__init__.py               |   6 +-
 hybridgnn/nn/models/hybrid_rhstransformer.py  | 159 +++++++++++++++++++
 hybridgnn/nn/models/transformer.py            | 126 ++++++++++++++
 6 files changed, 311 insertions(+), 6 deletions(-)
 create mode 100644 hybridgnn/nn/models/hybrid_rhstransformer.py
 create mode 100644 hybridgnn/nn/models/transformer.py

diff --git a/.gitignore b/.gitignore
index ced7769..bbc0e94 100644
--- a/.gitignore
+++ b/.gitignore
@@ -19,3 +19,4 @@ coverage.xml
 venv/*
 *.out
 data/**
+*.txt
diff --git a/README.md b/README.md
index e59c177..f0354b1 100644
--- a/README.md
+++ b/README.md
@@ -12,6 +12,7 @@
 Run [`benchmark/relbench_link_prediction_benchmark.py`](https://github.com/kumo-ai/hybridgnn/blob/master/benchmark/relbench_link_prediction_benchmark.py)
 
 ```sh
+python relbench_link_prediction_benchmark.py --dataset rel-hm --task user-item-purchase --model rhstransformer
 python relbench_link_prediction_benchmark.py --dataset rel-trial --task site-sponsor-run --model hybridgnn
 ```
 
@@ -31,4 +32,6 @@ pip install -e .
 
 # to run examples and benchmarks
 pip install -e '.[full]'
+
+pip install -U sentence-transformers
 ```
diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py
index 4a857fa..3a0f6ee 100644
--- a/benchmark/relbench_link_prediction_benchmark.py
+++ b/benchmark/relbench_link_prediction_benchmark.py
@@ -30,7 +30,7 @@
 from torch_geometric.utils.cross_entropy import sparse_cross_entropy
 from tqdm import tqdm
 
-from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN
+from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN, Hybrid_RHSTransformer
 from hybridgnn.utils import GloveTextEmbedding
 
 TRAIN_CONFIG_KEYS = ["batch_size", "gamma_rate", "base_lr"]
@@ -43,7 +43,7 @@
     "--model",
     type=str,
     default="hybridgnn",
-    choices=["hybridgnn", "idgnn", "shallowrhsgnn"],
+    choices=["hybridgnn", "idgnn", "shallowrhsgnn", "rhstransformer"],
 )
 parser.add_argument("--epochs", type=int, default=20)
 parser.add_argument("--num_trials", type=int, default=10,
@@ -102,7 +102,7 @@
     int(args.num_neighbors // 2**i) for i in range(args.num_layers)
 ]
 
-model_cls: Type[Union[IDGNN, HybridGNN, ShallowRHSGNN]]
+model_cls: Type[Union[IDGNN, HybridGNN, ShallowRHSGNN, Hybrid_RHSTransformer]]
 
 if args.model == "idgnn":
     model_search_space = {
@@ -127,6 +127,18 @@
         "gamma_rate": [0.9, 0.95, 1.],
     }
     model_cls = (HybridGNN if args.model == "hybridgnn" else ShallowRHSGNN)
+elif args.model == "rhstransformer":
+    model_search_space = {
+        "channels": [64, 128, 256],
+        "embedding_dim": [64, 128, 256],
+        "norm": ["layer_norm", "batch_norm"],
+    }
+    train_search_space = {
+        "batch_size": [256, 512, 1024],
+        "base_lr": [0.001, 0.01],
+        "gamma_rate": [0.9, 0.95, 1.],
+    }
+    model_cls = Hybrid_RHSTransformer
 
 
 def train(
@@ -164,7 +176,7 @@ def train(
         loss = F.binary_cross_entropy_with_logits(out, target)
         numel = out.numel()
 
-    elif args.model in ["hybridgnn", "shallowrhsgnn"]:
+    elif args.model in ["hybridgnn", "shallowrhsgnn", "rhstransformer"]:
         logits = model(batch, task.src_entity_table, task.dst_entity_table)
         edge_label_index = torch.stack([src_batch, dst_index], dim=0)
         loss = sparse_cross_entropy(logits, edge_label_index)
@@ -248,7 +260,7 @@ def train_and_eval_with_cfg(
         persistent_workers=args.num_workers > 0,
     )
 
-    if args.model in ["hybridgnn", "shallowrhsgnn"]:
+    if args.model in ["hybridgnn", "shallowrhsgnn", "rhstransformer"]:
         model_cfg["num_nodes"] = num_dst_nodes_dict["train"]
     elif args.model == "idgnn":
         model_cfg["out_channels"] = 1
diff --git a/hybridgnn/nn/models/__init__.py b/hybridgnn/nn/models/__init__.py
index c8e9ef2..f20d482 100644
--- a/hybridgnn/nn/models/__init__.py
+++ b/hybridgnn/nn/models/__init__.py
@@ -2,5 +2,9 @@
 from .idgnn import IDGNN
 from .hybridgnn import HybridGNN
 from .shallowrhsgnn import ShallowRHSGNN
+from .hybrid_rhstransformer import Hybrid_RHSTransformer
 
-__all__ = classes = ['HeteroGraphSAGE', 'IDGNN', 'HybridGNN', 'ShallowRHSGNN']
+__all__ = classes = [
+    'HeteroGraphSAGE', 'IDGNN', 'HybridGNN', 'ShallowRHSGNN',
+    'Hybrid_RHSTransformer'
+]
diff --git a/hybridgnn/nn/models/hybrid_rhstransformer.py b/hybridgnn/nn/models/hybrid_rhstransformer.py
new file mode 100644
index 0000000..56ae32e
--- /dev/null
+++ b/hybridgnn/nn/models/hybrid_rhstransformer.py
@@ -0,0 +1,159 @@
+from typing import Any, Dict
+
+import torch
+from torch import Tensor
+from torch_frame.data.stats import StatType
+from torch_geometric.data import HeteroData
+from torch_geometric.nn import MLP
+from torch_geometric.typing import NodeType
+
+from hybridgnn.nn.encoder import (
+    DEFAULT_STYPE_ENCODER_DICT,
+    HeteroEncoder,
+    HeteroTemporalEncoder,
+)
+from hybridgnn.nn.models import HeteroGraphSAGE
+from hybridgnn.nn.models.transformer import RHSTransformer
+
+
+class Hybrid_RHSTransformer(torch.nn.Module):
+    r"""A hybrid GNN that refines the sampled RHS embeddings with a
+    transformer before computing the link-prediction logits."""
+    def __init__(
+        self,
+        data: HeteroData,
+        col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],
+        num_nodes: int,
+        num_layers: int,
+        channels: int,
+        embedding_dim: int,
+        aggr: str = 'sum',
+        norm: str = 'layer_norm',
+        pe: str = "abs",
+    ) -> None:
+        super().__init__()
+
+        self.encoder = HeteroEncoder(
+            channels=channels,
+            node_to_col_names_dict={
+                node_type: data[node_type].tf.col_names_dict
+                for node_type in data.node_types
+            },
+            node_to_col_stats=col_stats_dict,
+            stype_encoder_cls_kwargs=DEFAULT_STYPE_ENCODER_DICT,
+        )
+        self.temporal_encoder = HeteroTemporalEncoder(
+            node_types=[
+                node_type for node_type in data.node_types
+                if "time" in data[node_type]
+            ],
+            channels=channels,
+        )
+        self.gnn = HeteroGraphSAGE(
+            node_types=data.node_types,
+            edge_types=data.edge_types,
+            channels=channels,
+            aggr=aggr,
+            num_layers=num_layers,
+        )
+        self.head = MLP(
+            channels,
+            out_channels=1,
+            norm=norm,
+            num_layers=1,
+        )
+        self.lhs_projector = torch.nn.Linear(channels, embedding_dim)
+
+        self.id_awareness_emb = torch.nn.Embedding(1, channels)
+        self.rhs_embedding = torch.nn.Embedding(num_nodes, embedding_dim)
+        self.lin_offset_idgnn = torch.nn.Linear(embedding_dim, 1)
+        self.lin_offset_embgnn = torch.nn.Linear(embedding_dim, 1)
+        self.rhs_transformer = RHSTransformer(
+            in_channels=channels,
+            out_channels=channels,
+            hidden_channels=channels,
+            heads=1,
+            dropout=0.2,
+            position_encoding=pe,
+        )
+
+        self.channels = channels
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        self.encoder.reset_parameters()
+        self.temporal_encoder.reset_parameters()
+        self.gnn.reset_parameters()
+        self.head.reset_parameters()
+        self.id_awareness_emb.reset_parameters()
+        self.rhs_embedding.reset_parameters()
+        self.lin_offset_embgnn.reset_parameters()
+        self.lin_offset_idgnn.reset_parameters()
+        self.lhs_projector.reset_parameters()
+        self.rhs_transformer.reset_parameters()
+
+    def forward(
+        self,
+        batch: HeteroData,
+        entity_table: NodeType,
+        dst_table: NodeType,
+    ) -> Tensor:
+        seed_time = batch[entity_table].seed_time
+        x_dict = self.encoder(batch.tf_dict)
+
+        # Add ID-awareness to the root nodes:
+        x_dict[entity_table][:seed_time.size(0)] += (
+            self.id_awareness_emb.weight)
+        rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict,
+                                              batch.batch_dict)
+
+        for node_type, rel_time in rel_time_dict.items():
+            x_dict[node_type] = x_dict[node_type] + rel_time
+
+        x_dict = self.gnn(
+            x_dict,
+            batch.edge_index_dict,
+        )
+
+        batch_size = seed_time.size(0)
+        # [batch_size, channels]
+        lhs_embedding = x_dict[entity_table][:batch_size]
+        lhs_embedding_projected = self.lhs_projector(lhs_embedding)
+        # [num_sampled_rhs, channels]
+        rhs_gnn_embedding = x_dict[dst_table]
+        # Global node IDs of the sampled RHS nodes: [num_sampled_rhs]
+        rhs_idgnn_index = batch.n_id_dict[dst_table]
+        # Seed node that each sampled RHS node belongs to: [num_sampled_rhs]
+        lhs_idgnn_batch = batch.batch_dict[dst_table]
+
+        # Let each seed's sampled RHS embeddings attend to each other:
+        rhs_gnn_embedding = self.rhs_transformer(rhs_gnn_embedding,
+                                                 lhs_idgnn_batch)
+
+        rhs_embedding = self.rhs_embedding  # [num_rhs_nodes, channels]
+        # [batch_size, num_rhs_nodes]
+        embgnn_logits = lhs_embedding_projected @ rhs_embedding.weight.t()
+
+        # Model the importance of the embedding-GNN prediction per LHS node:
+        embgnn_offset_logits = self.lin_offset_embgnn(
+            lhs_embedding_projected).flatten()
+        embgnn_logits += embgnn_offset_logits.view(-1, 1)
+
+        # Calculate the ID-GNN logits:
+        idgnn_logits = self.head(
+            rhs_gnn_embedding).flatten()  # [num_sampled_rhs]
+        # Since only two hops are sampled, no message passing reaches back
+        # to the LHS node, so its information is incorporated explicitly:
+        idgnn_logits += (
+            lhs_embedding[lhs_idgnn_batch] *  # [num_sampled_rhs, channels]
+            rhs_gnn_embedding).sum(dim=-1).flatten()
+
+        # Model the importance of the ID-GNN prediction per LHS node:
+        idgnn_offset_logits = self.lin_offset_idgnn(
+            lhs_embedding_projected).flatten()
+        idgnn_logits = idgnn_logits + idgnn_offset_logits[lhs_idgnn_batch]
+
+        # Overwrite the embedding-GNN logits for the sampled RHS nodes:
+        embgnn_logits[lhs_idgnn_batch, rhs_idgnn_index] = idgnn_logits
+        return embgnn_logits
diff --git a/hybridgnn/nn/models/transformer.py b/hybridgnn/nn/models/transformer.py
new file mode 100644
index 0000000..dcab570
--- /dev/null
+++ b/hybridgnn/nn/models/transformer.py
@@ -0,0 +1,126 @@
+from typing import Optional
+
+import torch
+from torch import Tensor
+from torch_geometric.nn.aggr.utils import MultiheadAttentionBlock
+from torch_geometric.nn.encoding import PositionalEncoding
+from torch_geometric.utils import to_dense_batch
+
+
+class RHSTransformer(torch.nn.Module):
+    r"""A module to attend to RHS embeddings with a transformer.
+
+    Args:
+        in_channels (int): The number of input channels of the RHS embedding.
+        out_channels (int): The number of output channels.
+        hidden_channels (int): The hidden channel dimension of the
+            transformer.
+        heads (int): The number of attention heads for the transformer.
+        num_transformer_blocks (int): The number of transformer blocks.
+        dropout (float): The dropout rate of the transformer.
+        position_encoding (str): The type of positional encoding, either
+            "abs" (absolute) or "rope" (rotary). (default: "abs")
+    """
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        hidden_channels: int = 64,
+        heads: int = 1,
+        num_transformer_blocks: int = 1,
+        dropout: float = 0.0,
+        position_encoding: str = "abs",
+    ) -> None:
+        super().__init__()
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.hidden_channels = hidden_channels
+        self.lin = torch.nn.Linear(in_channels, hidden_channels)
+        self.fc = torch.nn.Linear(hidden_channels, out_channels)
+        self.pe_type = position_encoding
+        self.pe = None
+        if position_encoding == "abs":
+            self.pe = PositionalEncoding(hidden_channels)
+        elif position_encoding == "rope":
+            # Rotary positional encodings for queries and keys:
+            self.q_pe = RotaryPositionalEmbeddings(hidden_channels)
+            self.k_pe = RotaryPositionalEmbeddings(hidden_channels)
+
+        self.blocks = torch.nn.ModuleList([
+            MultiheadAttentionBlock(
+                channels=hidden_channels,
+                heads=heads,
+                layer_norm=True,
+                dropout=dropout,
+            ) for _ in range(num_transformer_blocks)
+        ])
+
+    def reset_parameters(self):
+        for block in self.blocks:
+            block.reset_parameters()
+        self.lin.reset_parameters()
+        self.fc.reset_parameters()
+
+    def forward(
+        self,
+        rhs_embed: Tensor,
+        index: Tensor,
+        rhs_time: Optional[Tensor] = None,
+    ) -> Tensor:
+        r"""Returns the attended RHS embeddings."""
+        rhs_embed = self.lin(rhs_embed)
+
+        if self.pe_type == "abs":
+            if rhs_time is None:
+                rhs_embed = rhs_embed + self.pe(
+                    torch.arange(rhs_embed.size(0), device=rhs_embed.device))
+            else:
+                rhs_embed = rhs_embed + self.pe(rhs_time)
+
+        # Group the flat RHS embeddings into one sequence per seed node:
+        x, mask = to_dense_batch(rhs_embed, index)
+        for block in self.blocks:
+            if self.pe_type == "rope":
+                # Rotary encodings on queries and keys (positions within
+                # each sequence; timestamps would need densifying first):
+                x_q = self.q_pe(x)
+                x_k = self.k_pe(x)
+            else:
+                x_q = x
+                x_k = x
+            # Mask out padding entries introduced by `to_dense_batch`:
+            x = block(x_q, x_k, y_mask=mask)
+        x = x[mask]
+        return self.fc(x)
+
+
+class RotaryPositionalEmbeddings(torch.nn.Module):
+    r"""Rotary positional embeddings (RoPE), following
+    https://arxiv.org/abs/2104.09864."""
+    def __init__(self, channels: int, base: int = 10000) -> None:
+        super().__init__()
+        self.channels = channels
+        self.base = base
+        inv_freq = 1. / (base**(torch.arange(0, channels, 2).float() /
+                                channels))
+        # Register as a buffer so it moves with the module's device:
+        self.register_buffer('inv_freq', inv_freq)
+
+    def forward(self, x: Tensor, pos: Optional[Tensor] = None) -> Tensor:
+        seq_len = x.shape[1]
+        if pos is None:
+            pos = torch.arange(seq_len,
+                               device=x.device).type_as(self.inv_freq)
+        freqs = torch.einsum('i,j->ij', pos, self.inv_freq)
+        # Interleave the frequencies to match the pairwise rotation below:
+        emb = freqs.repeat_interleave(2, dim=-1)
+
+        cos = emb.cos().to(x.device)
+        sin = emb.sin().to(x.device)
+
+        # Rotate each (even, odd) channel pair: (x1, x2) -> (-x2, x1).
+        x1, x2 = x[..., ::2], x[..., 1::2]
+        rotated = torch.stack([-x2, x1], dim=-1).reshape(x.shape)
+
+        return x * cos + rotated * sin
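
A minimal usage sketch of the new `RHSTransformer` (toy, illustrative shapes;
`index` is assumed to be sorted, since `to_dense_batch` expects the entries of
each group to be consecutive):

    import torch
    from hybridgnn.nn.models.transformer import RHSTransformer

    # 10 sampled RHS nodes with 64 channels, grouped under 4 seed nodes:
    rhs_embed = torch.randn(10, 64)
    index = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 3, 3])

    transformer = RHSTransformer(in_channels=64, out_channels=64,
                                 hidden_channels=64, heads=1,
                                 position_encoding="abs")
    out = transformer(rhs_embed, index)  # -> shape [10, 64]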