diff --git a/.gitignore b/.gitignore
index ced7769..bbc0e94 100644
--- a/.gitignore
+++ b/.gitignore
@@ -19,3 +19,4 @@ coverage.xml
 venv/*
 *.out
 data/**
+*.txt
diff --git a/README.md b/README.md
index e59c177..d02d749 100644
--- a/README.md
+++ b/README.md
@@ -12,15 +12,17 @@
 Run [`benchmark/relbench_link_prediction_benchmark.py`](https://github.com/kumo-ai/hybridgnn/blob/master/benchmark/relbench_link_prediction_benchmark.py)
 
 ```sh
-python relbench_link_prediction_benchmark.py --dataset rel-trial --task site-sponsor-run --model hybridgnn
+python relbench_link_prediction_benchmark.py --dataset rel-stack --task user-post-comment --model rhstransformer --num_trials 10
+python relbench_link_prediction_benchmark.py --dataset rel-hm --task user-item-purchase --model rhstransformer
+python relbench_link_prediction_benchmark.py --dataset rel-trial --task site-sponsor-run --model hybridgnn --num_layers 4
 ```
 
 Run [`examples/relbench_example.py`](https://github.com/kumo-ai/hybridgnn/blob/master/examples/relbench_example.py)
 
 ```sh
-python relbench_example.py --dataset rel-trial --task site-sponsor-run --model hybridgnn
-python relbench_example.py --dataset rel-trial --task condition-sponsor-run --model hybridgnn
+python relbench_example.py --dataset rel-trial --task site-sponsor-run --model hybridgnn --num_layers 4
+python relbench_example.py --dataset rel-trial --task condition-sponsor-run --model hybridgnn --num_layers 4
 ```
@@ -31,5 +33,7 @@
 pip install -e .
 
 # to run examples and benchmarks
 pip install -e '.[full]'
+
+pip install -U sentence-transformers
 ```
diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py
index 79af142..e9cdd39 100644
--- a/benchmark/relbench_link_prediction_benchmark.py
+++ b/benchmark/relbench_link_prediction_benchmark.py
@@ -1,3 +1,7 @@
+"""
+$ python relbench_link_prediction_benchmark.py --dataset rel-stack --task post-post-related --model rhstransformer --num_trials 10
+"""
+
 import argparse
 import json
 import os
@@ -30,8 +34,8 @@
 from torch_geometric.utils.cross_entropy import sparse_cross_entropy
 from tqdm import tqdm
 
-from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN
+from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN, RHSTransformer
 from hybridgnn.utils import GloveTextEmbedding
 
 TRAIN_CONFIG_KEYS = ["batch_size", "gamma_rate", "base_lr"]
 LINK_PREDICTION_METRIC = "link_prediction_map"
@@ -43,7 +47,7 @@
     "--model",
     type=str,
     default="hybridgnn",
-    choices=["hybridgnn", "idgnn", "shallowrhsgnn"],
+    choices=["hybridgnn", "idgnn", "shallowrhsgnn", "rhstransformer"],
 )
 parser.add_argument("--epochs", type=int, default=20)
 parser.add_argument("--num_trials", type=int, default=10,
@@ -102,7 +106,7 @@
     int(args.num_neighbors // 2**i) for i in range(args.num_layers)
 ]
 
-model_cls: Type[Union[IDGNN, HybridGNN, ShallowRHSGNN]]
+model_cls: Type[Union[IDGNN, HybridGNN, ShallowRHSGNN, RHSTransformer]]
 
 if args.model == "idgnn":
     model_search_space = {
@@ -131,13 +135,30 @@
         "gamma_rate": [0.8, 1.],
     }
     model_cls = (HybridGNN if args.model == "hybridgnn" else
                  ShallowRHSGNN)
-
+elif args.model == "rhstransformer":
+    model_search_space = {
+        "encoder_channels": [64, 128],
+        "encoder_layers": [2, 4],
+        "channels": [64, 128],
+        "embedding_dim": [64, 128],
+        "norm": ["layer_norm", "batch_norm"],
+        "dropout": [0.1, 0.2],
+        "t_encoding_type": ["fuse", "absolute"],
+    }
+    train_search_space = {
+        "batch_size": [128, 256],
+        "base_lr": [0.0005, 0.01],
+        "gamma_rate": [0.9, 1.0],
+    }
+    model_cls = RHSTransformer
+
 
 def train(
     model: torch.nn.Module,
     optimizer: torch.optim.Optimizer,
     loader: NeighborLoader,
     train_sparse_tensor: SparseTensor,
+    epoch: int,
 ) -> float:
     model.train()
@@ -172,9 +193,9 @@ def train(
-        elif args.model in ["hybridgnn", "shallowrhsgnn"]:
+        elif args.model in ["hybridgnn", "shallowrhsgnn", "rhstransformer"]:
             logits = model(batch, task.src_entity_table,
                            task.dst_entity_table)
             edge_label_index = torch.stack([src_batch, dst_index], dim=0)
             loss = sparse_cross_entropy(logits, edge_label_index)
             numel = len(batch[task.dst_entity_table].batch)
 
         loss.backward()
         optimizer.step()
@@ -209,15 +230,15 @@ def test(model: torch.nn.Module, loader: NeighborLoader, stage: str) -> float:
             scores = torch.zeros(batch_size, task.num_dst_nodes,
                                  device=out.device)
             scores[batch[task.dst_entity_table].batch,
                    batch[task.dst_entity_table].n_id] = torch.sigmoid(out)
-        elif args.model in ["hybridgnn", "shallowrhsgnn"]:
+        elif args.model in ["hybridgnn", "shallowrhsgnn", "rhstransformer"]:
             # Get ground-truth
             out = model(batch, task.src_entity_table,
                         task.dst_entity_table).detach()
             scores = torch.sigmoid(out)
         else:
             raise ValueError(f"Unsupported model type: {args.model}.")
 
         _, pred_mini = torch.topk(scores, k=task.eval_k, dim=1)
         pred_list.append(pred_mini)
     pred = torch.cat(pred_list, dim=0).cpu().numpy()
@@ -252,7 +273,7 @@ def train_and_eval_with_cfg(
         persistent_workers=args.num_workers > 0,
     )
 
-    if args.model in ["hybridgnn", "shallowrhsgnn"]:
+    if args.model in ["hybridgnn", "shallowrhsgnn", "rhstransformer"]:
         model_cfg["num_nodes"] = num_dst_nodes_dict["train"]
     elif args.model == "idgnn":
         model_cfg["out_channels"] = 1
@@ -285,7 +306,7 @@ def train_and_eval_with_cfg(
 
         train_sparse_tensor = SparseTensor(dst_nodes_dict["train"][1],
                                            device=device)
        train_loss = train(model, optimizer, loader_dict["train"],
-                           train_sparse_tensor)
+                           train_sparse_tensor, epoch)
         optimizer.zero_grad()
         val_metric = test(model, loader_dict["val"], "val")
diff --git a/hybridgnn/nn/encoder.py b/hybridgnn/nn/encoder.py
index 9844956..2a8c933 100644
--- a/hybridgnn/nn/encoder.py
+++ b/hybridgnn/nn/encoder.py
@@ -19,6 +19,9 @@
     torch_frame.timestamp: (torch_frame.nn.TimestampEncoder, {}),
 }
 SECONDS_IN_A_DAY = 60 * 60 * 24
+SECONDS_IN_A_WEEK = 7 * 60 * 60 * 24
+SECONDS_IN_AN_HOUR = 60 * 60
+SECONDS_IN_A_MINUTE = 60
 
 
 class HeteroEncoder(torch.nn.Module):
@@ -92,9 +95,18 @@ def forward(
 class HeteroTemporalEncoder(torch.nn.Module):
-    def __init__(self, node_types: List[NodeType], channels: int) -> None:
+    def __init__(self, node_types: List[NodeType], channels: int,
+                 encoding_type: str = "absolute") -> None:
+        r"""Temporal encoder.
+
+        Args:
+            encoding_type (str, optional): The type of time encoding to use,
+                one of :obj:`["absolute", "fuse"]`.
+                (default: :obj:`"absolute"`)
+        """
         super().__init__()
+        self.encoding_type = encoding_type
["absolute", "fuse"] + self.encoder_dict = torch.nn.ModuleDict({ node_type: PositionalEncoding(channels) @@ -106,11 +116,19 @@ def __init__(self, node_types: List[NodeType], channels: int) -> None: for node_type in node_types }) + if (self.encoding_type == "fuse"): + time_dim = 3 # hour, day, week + self.time_fuser = torch.nn.Linear(time_dim, channels) + def reset_parameters(self) -> None: for encoder in self.encoder_dict.values(): encoder.reset_parameters() for lin in self.lin_dict.values(): lin.reset_parameters() + if (self.encoding_type == "learnable"): + self.day_pe.reset_parameters() + elif (self.encoding_type == "fuse"): + self.time_fuser.reset_parameters() def forward( self, @@ -122,9 +140,17 @@ def forward( for node_type, time in time_dict.items(): rel_time = seed_time[batch_dict[node_type]] - time - rel_time = rel_time / SECONDS_IN_A_DAY - - x = self.encoder_dict[node_type](rel_time) + + if (self.encoding_type == "absolute"): + rel_time = rel_time / SECONDS_IN_A_DAY + x = self.encoder_dict[node_type](rel_time) + elif (self.encoding_type == "fuse"): + rel_hour = (rel_time // SECONDS_IN_A_HOUR).view(-1,1) + rel_day = (rel_time // SECONDS_IN_A_DAY).view(-1,1) + rel_week = (rel_time // SECONDS_IN_A_WEEK).view(-1,1) + time_embed = torch.cat((rel_hour, rel_day, rel_week),dim=1).float() + time_embed = torch.nn.functional.normalize(time_embed, p=2.0, dim=1) #normalize hour, day, week into same scale + x = self.time_fuser(time_embed) x = self.lin_dict[node_type](x) out_dict[node_type] = x diff --git a/hybridgnn/nn/models/RHSTransformer.py b/hybridgnn/nn/models/RHSTransformer.py new file mode 100644 index 0000000..585d337 --- /dev/null +++ b/hybridgnn/nn/models/RHSTransformer.py @@ -0,0 +1,158 @@ +from typing import Any, Dict, Optional, Type + +import torch +from torch import Tensor +from torch_frame.data.stats import StatType +from torch_frame.nn.models import ResNet +from torch_geometric.data import HeteroData +from torch_geometric.nn import MLP +from torch_geometric.typing import NodeType + +from hybridgnn.nn.encoder import ( + DEFAULT_STYPE_ENCODER_DICT, + HeteroEncoder, + HeteroTemporalEncoder, +) +from hybridgnn.nn.models import HeteroGraphSAGE +from hybridgnn.nn.models.transformer import Transformer + + +class RHSTransformer(torch.nn.Module): + r"""Implementation of RHSTransformer model.""" + def __init__( + self, + data: HeteroData, + col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]], + num_nodes: int, + num_layers: int, + channels: int, + embedding_dim: int, + aggr: str = 'sum', + norm: str = 'layer_norm', + dropout: float = 0.2, + heads: int = 1, + t_encoding_type: str = "absolute", + torch_frame_model_cls: Type[torch.nn.Module] = ResNet, + torch_frame_model_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__() + + self.encoder = HeteroEncoder( + channels=channels, + node_to_col_names_dict={ + node_type: data[node_type].tf.col_names_dict + for node_type in data.node_types + }, + node_to_col_stats=col_stats_dict, + stype_encoder_cls_kwargs=DEFAULT_STYPE_ENCODER_DICT, + torch_frame_model_cls=torch_frame_model_cls, + torch_frame_model_kwargs=torch_frame_model_kwargs, + ) + self.temporal_encoder = HeteroTemporalEncoder( + node_types=[ + node_type for node_type in data.node_types + if "time" in data[node_type] + ], + channels=channels, + encoding_type=t_encoding_type, + ) + self.gnn = HeteroGraphSAGE( + node_types=data.node_types, + edge_types=data.edge_types, + channels=channels, + aggr=aggr, + num_layers=num_layers, + ) + self.head = MLP( + channels, + 
+            out_channels=1,
+            norm=norm,
+            num_layers=1,
+        )
+        self.lhs_projector = torch.nn.Linear(channels, embedding_dim)
+        self.id_awareness_emb = torch.nn.Embedding(1, channels)
+        self.rhs_embedding = torch.nn.Embedding(num_nodes, embedding_dim)
+        self.lin_offset_idgnn = torch.nn.Linear(embedding_dim, 1)
+        self.lin_offset_embgnn = torch.nn.Linear(embedding_dim, 1)
+        self.channels = channels
+
+        self.rhs_transformer = Transformer(
+            in_channels=channels,
+            out_channels=channels,
+            hidden_channels=channels,
+            heads=heads,
+            dropout=dropout,
+        )
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        self.encoder.reset_parameters()
+        self.temporal_encoder.reset_parameters()
+        self.gnn.reset_parameters()
+        self.head.reset_parameters()
+        self.id_awareness_emb.reset_parameters()
+        self.rhs_embedding.reset_parameters()
+        self.lin_offset_embgnn.reset_parameters()
+        self.lin_offset_idgnn.reset_parameters()
+        self.lhs_projector.reset_parameters()
+        self.rhs_transformer.reset_parameters()
+
+    def forward(
+        self,
+        batch: HeteroData,
+        entity_table: NodeType,
+        dst_table: NodeType,
+    ) -> Tensor:
+        seed_time = batch[entity_table].seed_time
+        x_dict = self.encoder(batch.tf_dict)
+
+        # Add ID-awareness to the root node:
+        x_dict[entity_table][:seed_time.size(0)] += (
+            self.id_awareness_emb.weight)
+        rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict,
+                                              batch.batch_dict)
+
+        for node_type, rel_time in rel_time_dict.items():
+            x_dict[node_type] = x_dict[node_type] + rel_time
+
+        x_dict = self.gnn(
+            x_dict,
+            batch.edge_index_dict,
+        )
+
+        batch_size = seed_time.size(0)
+        # batch_size, channel:
+        lhs_embedding = x_dict[entity_table][:batch_size]
+        lhs_embedding_projected = self.lhs_projector(lhs_embedding)
+        rhs_gnn_embedding = x_dict[dst_table]  # num_sampled_rhs, channel
+        rhs_idgnn_index = batch.n_id_dict[dst_table]  # num_sampled_rhs
+        lhs_idgnn_batch = batch.batch_dict[dst_table]  # num_sampled_rhs
+        rhs_embedding = self.rhs_embedding  # num_rhs_nodes, channel
+
+        # batch_size, num_rhs_nodes:
+        embgnn_logits = lhs_embedding_projected @ rhs_embedding.weight.t()
+
+        # Model the importance of embedding-GNN prediction for each lhs node
+        embgnn_offset_logits = self.lin_offset_embgnn(
+            lhs_embedding_projected).flatten()
+        embgnn_logits += embgnn_offset_logits.view(-1, 1)
+
+        # Let the sampled RHS embeddings of each seed node attend to each
+        # other via the transformer:
+        rhs_gnn_embedding = self.rhs_transformer(rhs_gnn_embedding,
+                                                 lhs_idgnn_batch,
+                                                 batch_size=batch_size)
+
+        # Calculate idgnn logits (one score per sampled RHS node):
+        idgnn_logits = self.head(rhs_gnn_embedding).flatten()
+        # Because we are only doing 2 hops, we are not really sampling
+        # information from the lhs. Therefore, we need to incorporate this
+        # information using lhs_embedding[lhs_idgnn_batch] * rhs_gnn_embedding
+        idgnn_logits += (
+            lhs_embedding[lhs_idgnn_batch] *  # num_sampled_rhs, channel
+            rhs_gnn_embedding).sum(dim=-1).flatten()  # num_sampled_rhs
+
+        # Model the importance of ID-GNN prediction for each lhs node
+        idgnn_offset_logits = self.lin_offset_idgnn(
+            lhs_embedding_projected).flatten()
+        idgnn_logits = idgnn_logits + idgnn_offset_logits[lhs_idgnn_batch]
+
+        # Write the refined scores of the sampled RHS nodes into the dense
+        # embedding-GNN logits:
+        embgnn_logits[lhs_idgnn_batch, rhs_idgnn_index] = idgnn_logits
+        return embgnn_logits
diff --git a/hybridgnn/nn/models/__init__.py b/hybridgnn/nn/models/__init__.py
index c8e9ef2..dadbc32 100644
--- a/hybridgnn/nn/models/__init__.py
+++ b/hybridgnn/nn/models/__init__.py
@@ -2,5 +2,9 @@
 from .idgnn import IDGNN
 from .hybridgnn import HybridGNN
 from .shallowrhsgnn import ShallowRHSGNN
+from .RHSTransformer import RHSTransformer
 
-__all__ = classes = ['HeteroGraphSAGE', 'IDGNN', 'HybridGNN', 'ShallowRHSGNN']
+__all__ = classes = [
+    'HeteroGraphSAGE', 'IDGNN', 'HybridGNN', 'ShallowRHSGNN',
+    'RHSTransformer'
+]
diff --git a/hybridgnn/nn/models/transformer.py b/hybridgnn/nn/models/transformer.py
new file mode 100644
index 0000000..f820802
--- /dev/null
+++ b/hybridgnn/nn/models/transformer.py
@@ -0,0 +1,73 @@
+import torch
+from torch import Tensor
+from torch_geometric.nn.aggr.utils import MultiheadAttentionBlock
+from torch_geometric.utils import to_dense_batch
+
+
+class Transformer(torch.nn.Module):
+    r"""A module to attend to RHS embeddings with a transformer.
+
+    Args:
+        in_channels (int): The number of input channels of the RHS embedding.
+        out_channels (int): The number of output channels.
+        hidden_channels (int): The hidden channel dimension of the
+            transformer.
+        heads (int): The number of attention heads for the transformer.
+        num_transformer_blocks (int): The number of transformer blocks.
+        dropout (float): The dropout rate of the transformer.
+    """
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        hidden_channels: int = 64,
+        heads: int = 1,
+        num_transformer_blocks: int = 1,
+        dropout: float = 0.0,
+    ) -> None:
+        super().__init__()
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.hidden_channels = hidden_channels
+        self.lin = torch.nn.Linear(in_channels, hidden_channels)
+        self.fc = torch.nn.Linear(hidden_channels, out_channels)
+        self.blocks = torch.nn.ModuleList([
+            MultiheadAttentionBlock(
+                channels=hidden_channels,
+                heads=heads,
+                layer_norm=True,
+                dropout=dropout,
+            ) for _ in range(num_transformer_blocks)
+        ])
+
+    def reset_parameters(self) -> None:
+        for block in self.blocks:
+            block.reset_parameters()
+        self.lin.reset_parameters()
+        self.fc.reset_parameters()
+
+    def forward(self, rhs_embed: Tensor, index: Tensor,
+                batch_size: int) -> Tensor:
+        r"""Returns the attended RHS embeddings."""
+        rhs_embed = self.lin(rhs_embed)
+
+        # `to_dense_batch` expects a sorted index, so sort the embeddings
+        # by `index` and remember the permutation to restore the original
+        # row order afterwards:
+        index, perm = torch.sort(index, stable=True)
+        rhs_embed = rhs_embed[perm]
+        reverse = self.inverse_permutation(perm)
+
+        x, mask = to_dense_batch(rhs_embed, index, batch_size=batch_size)
+        for block in self.blocks:
+            # Pass the mask so padded entries do not take part in attention:
+            x = block(x, x, mask, mask)
+        x = x[mask]
+        x = x[reverse]
+        return self.fc(x)
+
+    def inverse_permutation(self, perm: Tensor) -> Tensor:
+        inv = torch.empty_like(perm)
+        inv[perm] = torch.arange(perm.size(0), device=perm.device)
+        return inv
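
As a sanity check on the new `Transformer` module above, here is a minimal, self-contained usage sketch. All shapes and values are made up for illustration; it assumes this PR is installed so that the module is importable:

```python
import torch

from hybridgnn.nn.models.transformer import Transformer

num_sampled_rhs, channels, batch_size = 12, 64, 4
rhs_embed = torch.randn(num_sampled_rhs, channels)        # sampled RHS embeddings
index = torch.randint(0, batch_size, (num_sampled_rhs,))  # seed of each RHS row

model = Transformer(in_channels=channels, out_channels=channels,
                    hidden_channels=channels, heads=1)
out = model(rhs_embed, index, batch_size=batch_size)
assert out.shape == rhs_embed.shape  # one attended embedding per RHS row
```

Note that `index` does not need to be pre-sorted: the module sorts internally, attends within each seed's dense block, and restores the original row order before the output projection.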
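The "fuse" temporal encoding added to `HeteroTemporalEncoder` can also be illustrated in isolation: each relative time is turned into counts of full hours, days, and weeks, the three counts are L2-normalized per row, and a linear layer projects them to `channels`. A standalone sketch with hypothetical values (`time_fuser` here stands in for the module's internal layer):

```python
import torch

SECONDS_IN_AN_HOUR = 60 * 60
SECONDS_IN_A_DAY = 60 * 60 * 24
SECONDS_IN_A_WEEK = 7 * SECONDS_IN_A_DAY

# Two hypothetical gaps between a node's timestamp and the seed time:
rel_time = torch.tensor([90_000, 700_000])  # ~25 hours and ~8.1 days
time_embed = torch.cat((
    (rel_time // SECONDS_IN_AN_HOUR).view(-1, 1),  # [25, 194]
    (rel_time // SECONDS_IN_A_DAY).view(-1, 1),    # [1, 8]
    (rel_time // SECONDS_IN_A_WEEK).view(-1, 1),   # [0, 1]
), dim=1).float()
time_embed = torch.nn.functional.normalize(time_embed, p=2.0, dim=1)
time_fuser = torch.nn.Linear(3, 64)  # 3 = (hour, day, week), 64 = channels
x = time_fuser(time_embed)           # shape: [2, 64]
```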
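Finally, the way `RHSTransformer.forward` combines its two score sources reduces to a small sketch: the dense embedding-GNN logits cover every RHS node, and the refined ID-GNN logits overwrite the entries of the RHS nodes that were actually sampled. All tensors below are hypothetical stand-ins for the ones computed in the model:

```python
import torch

batch_size, num_rhs_nodes, num_sampled_rhs = 2, 5, 3
embgnn_logits = torch.randn(batch_size, num_rhs_nodes)  # scores for all RHS nodes
idgnn_logits = torch.randn(num_sampled_rhs)             # refined, sampled RHS only
lhs_idgnn_batch = torch.tensor([0, 0, 1])               # seed of each sampled RHS
rhs_idgnn_index = torch.tensor([2, 4, 1])               # global id of each sampled RHS

# Sampled entries overwrite the dense embedding scores in place:
embgnn_logits[lhs_idgnn_batch, rhs_idgnn_index] = idgnn_logits
```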