Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding RHS transformer #19

Draft
wants to merge 23 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
829fee5
adding initial implementation of rhs transformer to repo
andyhuang-kumo Aug 8, 2024
12a83bd
adding rhs transformer to the benchmark script
andyhuang-kumo Aug 9, 2024
a3a9779
rhs transformer upload
andyhuang-kumo Aug 12, 2024
988781b
updating tr
andyhuang-kumo Aug 12, 2024
e5faa22
running code
andyhuang-kumo Aug 12, 2024
5e904ea
running code
andyhuang-kumo Aug 12, 2024
d6ff169
adding transformer changes
andyhuang-kumo Aug 13, 2024
4884b83
permute the index, the rhs and then reverse it
andyhuang-kumo Aug 13, 2024
08041cb
removing none to replace with None
andyhuang-kumo Aug 13, 2024
29c70ad
add time fuse encoder to extract time pe
andyhuang-kumo Aug 14, 2024
f58b528
update hyperparameter options
andyhuang-kumo Aug 14, 2024
8f4966c
adding rerank_transformer
andyhuang-kumo Aug 19, 2024
90a8817
setting zeros for not used logits in rerank transformer
andyhuang-kumo Aug 19, 2024
1eb769c
adding reranker transformer
andyhuang-kumo Aug 21, 2024
ea58dd0
Merge branch 'master' into rhs_tr
andyhuang-kumo Aug 21, 2024
75445ce
updating RHS transformer code
andyhuang-kumo Aug 21, 2024
7380a17
updating rerank_transformer
andyhuang-kumo Aug 23, 2024
7622b94
push current version
andyhuang-kumo Aug 23, 2024
dc73f7c
converging to the follow implementation
andyhuang-kumo Aug 23, 2024
38f9cf4
training both from epoch 0, potentially try better ways
andyhuang-kumo Aug 23, 2024
8e846dd
for transformer, not training with nodes whose prediction is not within
andyhuang-kumo Aug 25, 2024
41a38cf
commiting before clean up
andyhuang-kumo Aug 26, 2024
3e71188
clean up, removing reference to rerank transformer
andyhuang-kumo Aug 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ coverage.xml
venv/*
*.out
data/**
*.txt
10 changes: 7 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,17 @@
Run [`benchmark/relbench_link_prediction_benchmark.py`](https://github.com/kumo-ai/hybridgnn/blob/master/benchmark/relbench_link_prediction_benchmark.py)

```sh
python relbench_link_prediction_benchmark.py --dataset rel-trial --task site-sponsor-run --model hybridgnn
python relbench_link_prediction_benchmark.py --dataset rel-stack --task user-post-comment --model rhstransformer --num_trials 10
python relbench_link_prediction_benchmark.py --dataset rel-hm --task user-item-purcahse --model rhstransformer
python relbench_link_prediction_benchmark.py --dataset rel-trial --task site-sponsor-run --model hybridgnn --num_layers 4
```


Run [`examples/relbench_example.py`](https://github.com/kumo-ai/hybridgnn/blob/master/examples/relbench_example.py)

```sh
python relbench_example.py --dataset rel-trial --task site-sponsor-run --model hybridgnn
python relbench_example.py --dataset rel-trial --task condition-sponsor-run --model hybridgnn
python relbench_example.py --dataset rel-trial --task site-sponsor-run --model hybridgnn --num_layers 4
python relbench_example.py --dataset rel-trial --task condition-sponsor-run --model hybridgnn --num_layers 4
```


Expand All @@ -31,4 +33,6 @@ pip install -e .

# to run examples and benchmarks
pip install -e '.[full]'

pip install -U sentence-transformers
```
52 changes: 44 additions & 8 deletions benchmark/relbench_link_prediction_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""
$ python relbench_link_prediction_benchmark.py --dataset rel-stack --task post-post-related --model rhstransformer --num_trials 10
"""

import argparse
import json
import os
Expand Down Expand Up @@ -30,8 +34,12 @@
from torch_geometric.utils.cross_entropy import sparse_cross_entropy
from tqdm import tqdm

from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN
from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN, RHSTransformer
from hybridgnn.utils import GloveTextEmbedding
from torch_geometric.utils import index_to_mask
from torch_geometric.utils.map import map_index



TRAIN_CONFIG_KEYS = ["batch_size", "gamma_rate", "base_lr"]
LINK_PREDICTION_METRIC = "link_prediction_map"
Expand All @@ -43,7 +51,7 @@
"--model",
type=str,
default="hybridgnn",
choices=["hybridgnn", "idgnn", "shallowrhsgnn"],
choices=["hybridgnn", "idgnn", "shallowrhsgnn", "rhstransformer"],
)
parser.add_argument("--epochs", type=int, default=20)
parser.add_argument("--num_trials", type=int, default=10,
Expand All @@ -64,7 +72,8 @@
parser.add_argument("--result_path", type=str, default="result")
args = parser.parse_args()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cpu"
if torch.cuda.is_available():
torch.set_num_threads(1)
seed_everything(args.seed)
Expand Down Expand Up @@ -102,7 +111,7 @@
int(args.num_neighbors // 2**i) for i in range(args.num_layers)
]

model_cls: Type[Union[IDGNN, HybridGNN, ShallowRHSGNN]]
model_cls: Type[Union[IDGNN, HybridGNN, ShallowRHSGNN, RHSTransformer]]

if args.model == "idgnn":
model_search_space = {
Expand Down Expand Up @@ -131,13 +140,29 @@
"gamma_rate": [0.8, 1.],
}
model_cls = (HybridGNN if args.model == "hybridgnn" else ShallowRHSGNN)

elif args.model in ["rhstransformer"]:
model_search_space = {
"encoder_channels": [64, 128],
"encoder_layers": [2, 4],
"channels": [64, 128],
"embedding_dim": [64, 128],
"norm": ["layer_norm", "batch_norm"],
"dropout": [0.1, 0.2],
"t_encoding_type": ["fuse", "absolute"],
}
train_search_space = {
"batch_size": [128, 256],
"base_lr": [0.0005, 0.01],
"gamma_rate": [0.9, 1.0],
}
model_cls = RHSTransformer

def train(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
loader: NeighborLoader,
train_sparse_tensor: SparseTensor,
epoch:int,
) -> float:
model.train()

Expand Down Expand Up @@ -173,6 +198,11 @@ def train(
edge_label_index = torch.stack([src_batch, dst_index], dim=0)
loss = sparse_cross_entropy(logits, edge_label_index)
numel = len(batch[task.dst_entity_table].batch)
elif args.model in ["rhstransformer"]:
logits = model(batch, task.src_entity_table, task.dst_entity_table)
edge_label_index = torch.stack([src_batch, dst_index], dim=0)
loss = sparse_cross_entropy(logits, edge_label_index)
numel = len(batch[task.dst_entity_table].batch)
loss.backward()

optimizer.step()
Expand Down Expand Up @@ -209,15 +239,21 @@ def test(model: torch.nn.Module, loader: NeighborLoader, stage: str) -> float:
device=out.device)
scores[batch[task.dst_entity_table].batch,
batch[task.dst_entity_table].n_id] = torch.sigmoid(out)
_, pred_mini = torch.topk(scores, k=task.eval_k, dim=1)
elif args.model in ["hybridgnn", "shallowrhsgnn"]:
# Get ground-truth
out = model(batch, task.src_entity_table,
task.dst_entity_table).detach()
scores = torch.sigmoid(out)
_, pred_mini = torch.topk(scores, k=task.eval_k, dim=1)
elif args.model in ["rhstransformer"]:
out = model(batch, task.src_entity_table,
task.dst_entity_table).detach()
scores = torch.sigmoid(out)
_, pred_mini = torch.topk(scores, k=task.eval_k, dim=1)
else:
raise ValueError(f"Unsupported model type: {args.model}.")

_, pred_mini = torch.topk(scores, k=task.eval_k, dim=1)
pred_list.append(pred_mini)

pred = torch.cat(pred_list, dim=0).cpu().numpy()
Expand Down Expand Up @@ -252,7 +288,7 @@ def train_and_eval_with_cfg(
persistent_workers=args.num_workers > 0,
)

if args.model in ["hybridgnn", "shallowrhsgnn"]:
if args.model in ["hybridgnn", "shallowrhsgnn", "rhstransformer"]:
model_cfg["num_nodes"] = num_dst_nodes_dict["train"]
elif args.model == "idgnn":
model_cfg["out_channels"] = 1
Expand Down Expand Up @@ -285,7 +321,7 @@ def train_and_eval_with_cfg(
train_sparse_tensor = SparseTensor(dst_nodes_dict["train"][1],
device=device)
train_loss = train(model, optimizer, loader_dict["train"],
train_sparse_tensor)
train_sparse_tensor, epoch)
optimizer.zero_grad()
val_metric = test(model, loader_dict["val"], "val")

Expand Down
34 changes: 30 additions & 4 deletions hybridgnn/nn/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
torch_frame.timestamp: (torch_frame.nn.TimestampEncoder, {}),
}
SECONDS_IN_A_DAY = 60 * 60 * 24
SECONDS_IN_A_WEEK = 7 * 60 * 60 * 24
SECONDS_IN_A_HOUR = 60 * 60
SECONDS_IN_A_MINUTE = 60


class HeteroEncoder(torch.nn.Module):
Expand Down Expand Up @@ -92,9 +95,16 @@ def forward(


class HeteroTemporalEncoder(torch.nn.Module):
def __init__(self, node_types: List[NodeType], channels: int) -> None:
def __init__(self, node_types: List[NodeType], channels: int,
encoding_type: Optional[str] = "absolute",) -> None:
r"""
temporal encoder:
encoding_type (str, optional): which type of time encoding to use, options are ["absolute", "learnable", "fuse"]
"""
super().__init__()

self.encoding_type = encoding_type # ["absolute", "fuse"]

self.encoder_dict = torch.nn.ModuleDict({
node_type:
PositionalEncoding(channels)
Expand All @@ -106,11 +116,19 @@ def __init__(self, node_types: List[NodeType], channels: int) -> None:
for node_type in node_types
})

if (self.encoding_type == "fuse"):
time_dim = 3 # hour, day, week
self.time_fuser = torch.nn.Linear(time_dim, channels)

def reset_parameters(self) -> None:
for encoder in self.encoder_dict.values():
encoder.reset_parameters()
for lin in self.lin_dict.values():
lin.reset_parameters()
if (self.encoding_type == "learnable"):
self.day_pe.reset_parameters()
elif (self.encoding_type == "fuse"):
self.time_fuser.reset_parameters()

def forward(
self,
Expand All @@ -122,9 +140,17 @@ def forward(

for node_type, time in time_dict.items():
rel_time = seed_time[batch_dict[node_type]] - time
rel_time = rel_time / SECONDS_IN_A_DAY

x = self.encoder_dict[node_type](rel_time)

if (self.encoding_type == "absolute"):
rel_time = rel_time / SECONDS_IN_A_DAY
x = self.encoder_dict[node_type](rel_time)
elif (self.encoding_type == "fuse"):
rel_hour = (rel_time // SECONDS_IN_A_HOUR).view(-1,1)
rel_day = (rel_time // SECONDS_IN_A_DAY).view(-1,1)
rel_week = (rel_time // SECONDS_IN_A_WEEK).view(-1,1)
time_embed = torch.cat((rel_hour, rel_day, rel_week),dim=1).float()
time_embed = torch.nn.functional.normalize(time_embed, p=2.0, dim=1) #normalize hour, day, week into same scale
x = self.time_fuser(time_embed)
x = self.lin_dict[node_type](x)
out_dict[node_type] = x

Expand Down
158 changes: 158 additions & 0 deletions hybridgnn/nn/models/RHSTransformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
from typing import Any, Dict, Optional, Type

import torch
from torch import Tensor
from torch_frame.data.stats import StatType
from torch_frame.nn.models import ResNet
from torch_geometric.data import HeteroData
from torch_geometric.nn import MLP
from torch_geometric.typing import NodeType

from hybridgnn.nn.encoder import (
DEFAULT_STYPE_ENCODER_DICT,
HeteroEncoder,
HeteroTemporalEncoder,
)
from hybridgnn.nn.models import HeteroGraphSAGE
from hybridgnn.nn.models.transformer import Transformer


class RHSTransformer(torch.nn.Module):
r"""Implementation of RHSTransformer model."""
def __init__(
self,
data: HeteroData,
col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],
num_nodes: int,
num_layers: int,
channels: int,
embedding_dim: int,
aggr: str = 'sum',
norm: str = 'layer_norm',
dropout: float = 0.2,
heads: int = 1,
t_encoding_type: str = "absolute",
torch_frame_model_cls: Type[torch.nn.Module] = ResNet,
torch_frame_model_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__()

self.encoder = HeteroEncoder(
channels=channels,
node_to_col_names_dict={
node_type: data[node_type].tf.col_names_dict
for node_type in data.node_types
},
node_to_col_stats=col_stats_dict,
stype_encoder_cls_kwargs=DEFAULT_STYPE_ENCODER_DICT,
torch_frame_model_cls=torch_frame_model_cls,
torch_frame_model_kwargs=torch_frame_model_kwargs,
)
self.temporal_encoder = HeteroTemporalEncoder(
node_types=[
node_type for node_type in data.node_types
if "time" in data[node_type]
],
channels=channels,
encoding_type=t_encoding_type,
)
self.gnn = HeteroGraphSAGE(
node_types=data.node_types,
edge_types=data.edge_types,
channels=channels,
aggr=aggr,
num_layers=num_layers,
)
self.head = MLP(
channels,
out_channels=1,
norm=norm,
num_layers=1,
)
self.lhs_projector = torch.nn.Linear(channels, embedding_dim)
self.id_awareness_emb = torch.nn.Embedding(1, channels)
self.rhs_embedding = torch.nn.Embedding(num_nodes, embedding_dim)
self.lin_offset_idgnn = torch.nn.Linear(embedding_dim, 1)
self.lin_offset_embgnn = torch.nn.Linear(embedding_dim, 1)
self.channels = channels

self.rhs_transformer = Transformer(in_channels=channels,
out_channels=channels,
hidden_channels=channels,
heads=heads, dropout=dropout)

self.reset_parameters()

def reset_parameters(self) -> None:
self.encoder.reset_parameters()
self.temporal_encoder.reset_parameters()
self.gnn.reset_parameters()
self.head.reset_parameters()
self.id_awareness_emb.reset_parameters()
self.rhs_embedding.reset_parameters()
self.lin_offset_embgnn.reset_parameters()
self.lin_offset_idgnn.reset_parameters()
self.lhs_projector.reset_parameters()
self.rhs_transformer.reset_parameters()

def forward(
self,
batch: HeteroData,
entity_table: NodeType,
dst_table: NodeType,
) -> Tensor:
seed_time = batch[entity_table].seed_time
x_dict = self.encoder(batch.tf_dict)

# Add ID-awareness to the root node
x_dict[entity_table][:seed_time.size(0
)] += self.id_awareness_emb.weight
rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict,
batch.batch_dict)

for node_type, rel_time in rel_time_dict.items():
x_dict[node_type] = x_dict[node_type] + rel_time

x_dict = self.gnn(
x_dict,
batch.edge_index_dict,
)

batch_size = seed_time.size(0)
lhs_embedding = x_dict[entity_table][:
batch_size] # batch_size, channel
lhs_embedding_projected = self.lhs_projector(lhs_embedding)
rhs_gnn_embedding = x_dict[dst_table] # num_sampled_rhs, channel
rhs_idgnn_index = batch.n_id_dict[dst_table] # num_sampled_rhs
lhs_idgnn_batch = batch.batch_dict[dst_table] # batch_size
rhs_embedding = self.rhs_embedding # num_rhs_nodes, channel

embgnn_logits = lhs_embedding_projected @ rhs_embedding.weight.t(
) # batch_size, num_rhs_nodes

# Model the importance of embedding-GNN prediction for each lhs node
embgnn_offset_logits = self.lin_offset_embgnn(
lhs_embedding_projected).flatten()
embgnn_logits += embgnn_offset_logits.view(-1, 1)

#* transformer forward pass
rhs_gnn_embedding = self.rhs_transformer(rhs_gnn_embedding,
lhs_idgnn_batch, batch_size=batch_size)
# Calculate idgnn logits
idgnn_logits = self.head(
rhs_gnn_embedding).flatten() # num_sampled_rhs
# Because we are only doing 2 hop, we are not really sampling info from
# lhs therefore, we need to incorporate this information using
# lhs_embedding[lhs_idgnn_batch] * rhs_gnn_embedding
idgnn_logits += (
lhs_embedding[lhs_idgnn_batch] * # num_sampled_rhs, channel
rhs_gnn_embedding).sum(
dim=-1).flatten() # num_sampled_rhs, channel

# Model the importance of ID-GNN prediction for each lhs node
idgnn_offset_logits = self.lin_offset_idgnn(
lhs_embedding_projected).flatten()
idgnn_logits = idgnn_logits + idgnn_offset_logits[lhs_idgnn_batch]

embgnn_logits[lhs_idgnn_batch, rhs_idgnn_index] = idgnn_logits
return embgnn_logits
Loading