adding initial implementation of rhs transformer to repo
andyhuang-kumo committed Aug 8, 2024
1 parent 1a336eb commit 829fee5
Showing 6 changed files with 292 additions and 6 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -19,3 +19,4 @@ coverage.xml
venv/*
*.out
data/**
+*.txt
3 changes: 3 additions & 0 deletions README.md
@@ -12,6 +12,7 @@
Run [`benchmark/relbench_link_prediction_benchmark.py`](https://github.com/kumo-ai/hybridgnn/blob/master/benchmark/relbench_link_prediction_benchmark.py)

```sh
+python relbench_link_prediction_benchmark.py --dataset rel-hm --task user-item-purchase --model rhstransformer
python relbench_link_prediction_benchmark.py --dataset rel-trial --task site-sponsor-run --model hybridgnn
```

@@ -31,4 +32,6 @@ pip install -e .

# to run examples and benchmarks
pip install -e '.[full]'

+pip install -U sentence-transformers
```
22 changes: 17 additions & 5 deletions benchmark/relbench_link_prediction_benchmark.py
@@ -30,7 +30,7 @@
from torch_geometric.utils.cross_entropy import sparse_cross_entropy
from tqdm import tqdm

-from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN
+from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN, Hybrid_RHSTransformer
from hybridgnn.utils import GloveTextEmbedding

TRAIN_CONFIG_KEYS = ["batch_size", "gamma_rate", "base_lr"]
@@ -43,7 +43,7 @@
"--model",
type=str,
default="hybridgnn",
choices=["hybridgnn", "idgnn", "shallowrhsgnn"],
choices=["hybridgnn", "idgnn", "shallowrhsgnn", "rhstransformer"],
)
parser.add_argument("--epochs", type=int, default=20)
parser.add_argument("--num_trials", type=int, default=10,
@@ -102,7 +102,7 @@
    int(args.num_neighbors // 2**i) for i in range(args.num_layers)
]

-model_cls: Type[Union[IDGNN, HybridGNN, ShallowRHSGNN]]
+model_cls: Type[Union[IDGNN, HybridGNN, ShallowRHSGNN, Hybrid_RHSTransformer]]

if args.model == "idgnn":
    model_search_space = {
@@ -127,6 +127,18 @@
"gamma_rate": [0.9, 0.95, 1.],
}
model_cls = (HybridGNN if args.model == "hybridgnn" else ShallowRHSGNN)
elif args.model in ["rhstransformer"]:
model_search_space = {
"channels": [64, 128, 256],
"embedding_dim": [64, 128, 256],
"norm": ["layer_norm", "batch_norm"]
}
train_search_space = {
"batch_size": [256, 512, 1024],
"base_lr": [0.001, 0.01],
"gamma_rate": [0.9, 0.95, 1.],
}
model_cls = Hybrid_RHSTransformer


def train(
@@ -164,7 +176,7 @@ def train(

        loss = F.binary_cross_entropy_with_logits(out, target)
        numel = out.numel()
-    elif args.model in ["hybridgnn", "shallowrhsgnn"]:
+    elif args.model in ["hybridgnn", "shallowrhsgnn", "rhstransformer"]:
        logits = model(batch, task.src_entity_table, task.dst_entity_table)
        edge_label_index = torch.stack([src_batch, dst_index], dim=0)
        loss = sparse_cross_entropy(logits, edge_label_index)
@@ -248,7 +260,7 @@ def train_and_eval_with_cfg(
        persistent_workers=args.num_workers > 0,
    )

-    if args.model in ["hybridgnn", "shallowrhsgnn"]:
+    if args.model in ["hybridgnn", "shallowrhsgnn", "rhstransformer"]:
        model_cfg["num_nodes"] = num_dst_nodes_dict["train"]
    elif args.model == "idgnn":
        model_cfg["out_channels"] = 1
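The trial loop itself is outside this hunk, but the search spaces above are plain dicts of candidate values. A minimal sketch of how a single tuning trial could sample a concrete configuration from them; the helper name and the random-choice strategy are assumptions for illustration, not part of the diff:

```python
import random

# Stand-ins for the search spaces defined in the benchmark above.
model_search_space = {
    "channels": [64, 128, 256],
    "embedding_dim": [64, 128, 256],
    "norm": ["layer_norm", "batch_norm"],
}
train_search_space = {
    "batch_size": [256, 512, 1024],
    "base_lr": [0.001, 0.01],
    "gamma_rate": [0.9, 0.95, 1.],
}

def sample_cfg(space: dict) -> dict:
    # One trial = one random choice per hyperparameter.
    return {name: random.choice(values) for name, values in space.items()}

model_cfg = sample_cfg(model_search_space)
train_cfg = sample_cfg(train_search_space)
print(model_cfg, train_cfg)
```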
6 changes: 5 additions & 1 deletion hybridgnn/nn/models/__init__.py
@@ -2,5 +2,9 @@
from .idgnn import IDGNN
from .hybridgnn import HybridGNN
from .shallowrhsgnn import ShallowRHSGNN
+from .hybrid_rhstransformer import Hybrid_RHSTransformer

-__all__ = classes = ['HeteroGraphSAGE', 'IDGNN', 'HybridGNN', 'ShallowRHSGNN']
+__all__ = classes = [
+    'HeteroGraphSAGE', 'IDGNN', 'HybridGNN', 'ShallowRHSGNN',
+    'Hybrid_RHSTransformer'
+]
153 changes: 153 additions & 0 deletions hybridgnn/nn/models/hybrid_rhstransformer.py
@@ -0,0 +1,153 @@
from typing import Any, Dict

import torch
from torch import Tensor
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.nn import MLP
from torch_geometric.typing import NodeType

from hybridgnn.nn.encoder import (
    DEFAULT_STYPE_ENCODER_DICT,
    HeteroEncoder,
    HeteroTemporalEncoder,
)
from hybridgnn.nn.models import HeteroGraphSAGE
from hybridgnn.nn.models.transformer import RHSTransformer


class Hybrid_RHSTransformer(torch.nn.Module):
    r"""Hybrid-GNN model that refines the sampled RHS embeddings with a
    transformer (:class:`RHSTransformer`) before computing the
    link-prediction logits."""
    def __init__(
        self,
        data: HeteroData,
        col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],
        num_nodes: int,
        num_layers: int,
        channels: int,
        embedding_dim: int,
        aggr: str = 'sum',
        norm: str = 'layer_norm',
        pe: str = "abs",
    ) -> None:
        super().__init__()

        self.encoder = HeteroEncoder(
            channels=channels,
            node_to_col_names_dict={
                node_type: data[node_type].tf.col_names_dict
                for node_type in data.node_types
            },
            node_to_col_stats=col_stats_dict,
            stype_encoder_cls_kwargs=DEFAULT_STYPE_ENCODER_DICT,
        )
        self.temporal_encoder = HeteroTemporalEncoder(
            node_types=[
                node_type for node_type in data.node_types
                if "time" in data[node_type]
            ],
            channels=channels,
        )
        self.gnn = HeteroGraphSAGE(
            node_types=data.node_types,
            edge_types=data.edge_types,
            channels=channels,
            aggr=aggr,
            num_layers=num_layers,
        )
        self.head = MLP(
            channels,
            out_channels=1,
            norm=norm,
            num_layers=1,
        )
        self.lhs_projector = torch.nn.Linear(channels, embedding_dim)

        self.id_awareness_emb = torch.nn.Embedding(1, channels)
        self.rhs_embedding = torch.nn.Embedding(num_nodes, embedding_dim)
        self.lin_offset_idgnn = torch.nn.Linear(embedding_dim, 1)
        self.lin_offset_embgnn = torch.nn.Linear(embedding_dim, 1)
        self.rhs_transformer = RHSTransformer(
            in_channels=channels,
            out_channels=channels,
            hidden_channels=channels,
            heads=1,
            dropout=0.2,
            position_encoding=pe,
        )

        self.channels = channels

        self.reset_parameters()

    def reset_parameters(self) -> None:
        self.encoder.reset_parameters()
        self.temporal_encoder.reset_parameters()
        self.gnn.reset_parameters()
        self.head.reset_parameters()
        self.id_awareness_emb.reset_parameters()
        self.rhs_embedding.reset_parameters()
        self.lin_offset_embgnn.reset_parameters()
        self.lin_offset_idgnn.reset_parameters()
        self.lhs_projector.reset_parameters()
        self.rhs_transformer.reset_parameters()

    def forward(
        self,
        batch: HeteroData,
        entity_table: NodeType,
        dst_table: NodeType,
    ) -> Tensor:
        seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict)

        # Add ID-awareness to the root node
        x_dict[entity_table][:seed_time.size(0)] += self.id_awareness_emb.weight
        rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict,
                                              batch.batch_dict)

        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
        )

        batch_size = seed_time.size(0)
        lhs_embedding = x_dict[entity_table][:batch_size]  # batch_size, channel
        lhs_embedding_projected = self.lhs_projector(lhs_embedding)
        rhs_gnn_embedding = x_dict[dst_table]  # num_sampled_rhs, channel
        rhs_idgnn_index = batch.n_id_dict[dst_table]  # num_sampled_rhs
        lhs_idgnn_batch = batch.batch_dict[dst_table]  # num_sampled_rhs

        # Let the sampled RHS nodes of each seed attend to each other.
        rhs_gnn_embedding = self.rhs_transformer(rhs_gnn_embedding,
                                                 lhs_idgnn_batch)

        rhs_embedding = self.rhs_embedding  # num_rhs_nodes, channel
        # batch_size, num_rhs_nodes
        embgnn_logits = lhs_embedding_projected @ rhs_embedding.weight.t()

        # Model the importance of embedding-GNN prediction for each lhs node
        embgnn_offset_logits = self.lin_offset_embgnn(
            lhs_embedding_projected).flatten()
        embgnn_logits += embgnn_offset_logits.view(-1, 1)

        # Calculate idgnn logits
        idgnn_logits = self.head(rhs_gnn_embedding).flatten()  # num_sampled_rhs
        # Because we only sample two hops, information from the LHS does not
        # fully reach the sampled RHS nodes, so we add it back explicitly via
        # lhs_embedding[lhs_idgnn_batch] * rhs_gnn_embedding.
        idgnn_logits += (
            lhs_embedding[lhs_idgnn_batch] *  # num_sampled_rhs, channel
            rhs_gnn_embedding).sum(dim=-1)  # num_sampled_rhs

        # Model the importance of ID-GNN prediction for each lhs node
        idgnn_offset_logits = self.lin_offset_idgnn(
            lhs_embedding_projected).flatten()
        idgnn_logits = idgnn_logits + idgnn_offset_logits[lhs_idgnn_batch]

        # Overwrite the dense embedding-GNN scores with the ID-GNN scores for
        # the RHS nodes that were actually sampled.
        embgnn_logits[lhs_idgnn_batch, rhs_idgnn_index] = idgnn_logits
        return embgnn_logits
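The last two lines implement the hybrid fusion: the embedding-GNN produces a dense score for every RHS node, and wherever an RHS node was actually sampled into the subgraph, its ID-GNN score overwrites the dense one. A self-contained toy sketch of that assignment, with all names and sizes assumed for illustration:

```python
import torch

batch_size, num_rhs_nodes = 2, 5
embgnn_logits = torch.zeros(batch_size, num_rhs_nodes)  # dense scores

# Three sampled RHS nodes: two for seed 0, one for seed 1.
lhs_idgnn_batch = torch.tensor([0, 0, 1])   # seed of each sampled node
rhs_idgnn_index = torch.tensor([3, 1, 4])   # global id of each sampled node
idgnn_logits = torch.tensor([0.7, -0.2, 1.5])

# Sampled (seed, rhs) pairs take the ID-GNN score; all other pairs keep
# the embedding-GNN score.
embgnn_logits[lhs_idgnn_batch, rhs_idgnn_index] = idgnn_logits
print(embgnn_logits)
```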
113 changes: 113 additions & 0 deletions hybridgnn/nn/models/transformer.py
@@ -0,0 +1,113 @@
from typing import Optional

import torch
from torch import Tensor

from torch_geometric.nn.aggr.utils import MultiheadAttentionBlock
from torch_geometric.nn.encoding import PositionalEncoding
from torch_geometric.utils import to_dense_batch


class RHSTransformer(torch.nn.Module):
    r"""A module that lets the RHS embeddings attend to each other with a
    transformer.

    Args:
        in_channels (int): The number of input channels of the RHS embedding.
        out_channels (int): The number of output channels.
        hidden_channels (int): The hidden channel dimension of the
            transformer.
        heads (int): The number of attention heads for the transformer.
        num_transformer_blocks (int): The number of transformer blocks.
        dropout (float): The dropout rate of the transformer.
        position_encoding (str): The type of positional encoding, either
            :obj:`"abs"` (absolute) or :obj:`"rope"` (rotary).
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int = 64,
        heads: int = 1,
        num_transformer_blocks: int = 1,
        dropout: float = 0.0,
        position_encoding: str = "abs",
    ) -> None:
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.lin = torch.nn.Linear(in_channels, hidden_channels)
        self.fc = torch.nn.Linear(hidden_channels, out_channels)
        self.pe_type = position_encoding
        self.pe = None
        if position_encoding == "abs":
            self.pe = PositionalEncoding(hidden_channels)
        elif position_encoding == "rope":
            # Separate rotary encodings for queries and keys.
            self.q_pe = RotaryPositionalEmbeddings(hidden_channels)
            self.k_pe = RotaryPositionalEmbeddings(hidden_channels)

        self.blocks = torch.nn.ModuleList([
            MultiheadAttentionBlock(
                channels=hidden_channels,
                heads=heads,
                layer_norm=True,
                dropout=dropout,
            ) for _ in range(num_transformer_blocks)
        ])

    def reset_parameters(self):
        for block in self.blocks:
            block.reset_parameters()
        self.lin.reset_parameters()
        self.fc.reset_parameters()

    def forward(self, rhs_embed: Tensor, index: Tensor,
                rhs_time: Optional[Tensor] = None) -> Tensor:
        r"""Returns the RHS embeddings after attending over all RHS nodes
        that share the same seed node."""
        rhs_embed = self.lin(rhs_embed)

        if self.pe_type == "abs":
            if rhs_time is None:
                rhs_embed = rhs_embed + self.pe(
                    torch.arange(rhs_embed.size(0), device=rhs_embed.device))
            else:
                rhs_embed = rhs_embed + self.pe(rhs_time)

        # Group the RHS nodes of each seed into one (padded) sequence.
        x, mask = to_dense_batch(rhs_embed, index)
        for block in self.blocks:
            if self.pe_type == "rope":
                # Apply the rotary encoding to both queries and keys.
                x_q = self.q_pe(x, pos=rhs_time)
                x_k = self.k_pe(x, pos=rhs_time)
            else:
                x_q = x
                x_k = x
            x = block(x_q, x_k)
        # Drop the padding entries again.
        x = x[mask].view(-1, self.hidden_channels)
        return self.fc(x)


class RotaryPositionalEmbeddings(torch.nn.Module):
    r"""Rotary positional embeddings (RoPE) over interleaved feature pairs.
    ``channels`` is assumed to be even."""
    def __init__(self, channels: int, base: int = 10000) -> None:
        super().__init__()
        self.channels = channels
        self.base = base
        inv_freq = 1.0 / (base**(torch.arange(0, channels, 2).float() /
                                 channels))
        # Register as a buffer so it moves to the input's device with the
        # module instead of staying on CPU.
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x: Tensor, pos: Optional[Tensor] = None) -> Tensor:
        seq_len = x.shape[1]
        if pos is None:
            pos = torch.arange(seq_len,
                               device=x.device).type_as(self.inv_freq)
        # One rotation angle per feature pair: (seq_len, channels // 2).
        freqs = torch.einsum('i,j->ij', pos, self.inv_freq)
        # Repeat each angle twice so cos/sin line up with the interleaved
        # (even, odd) feature pairs rotated below.
        emb = torch.repeat_interleave(freqs, 2, dim=-1)

        cos = emb.cos()
        sin = emb.sin()

        # Rotate each pair (x1, x2) -> (x1*cos - x2*sin, x2*cos + x1*sin).
        x1, x2 = x[..., ::2], x[..., 1::2]
        rotated = torch.stack([-x2, x1], dim=-1).reshape(x.shape)

        return x * cos + rotated * sin
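A minimal smoke test for the module above, assuming the import path introduced by this commit; `index` plays the role of `lhs_idgnn_batch`, grouping the sampled RHS nodes by their seed before attention:

```python
import torch
from hybridgnn.nn.models.transformer import RHSTransformer

model = RHSTransformer(in_channels=64, out_channels=64,
                       hidden_channels=64, heads=1,
                       position_encoding="abs")

rhs_embed = torch.randn(10, 64)                       # 10 sampled RHS nodes
index = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 2])  # seed of each node

out = model(rhs_embed, index)
assert out.shape == (10, 64)  # one refined embedding per sampled RHS node
```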
