Adding RHS transformer #19

Draft · wants to merge 23 commits into master

Commits (23) · the diff below shows changes from 6 commits
829fee5
adding initial implementation of rhs transformer to repo
andyhuang-kumo Aug 8, 2024
12a83bd
adding rhs transformer to the benchmark script
andyhuang-kumo Aug 9, 2024
a3a9779
rhs transformer upload
andyhuang-kumo Aug 12, 2024
988781b
updating tr
andyhuang-kumo Aug 12, 2024
e5faa22
running code
andyhuang-kumo Aug 12, 2024
5e904ea
running code
andyhuang-kumo Aug 12, 2024
d6ff169
adding transformer changes
andyhuang-kumo Aug 13, 2024
4884b83
permute the index, the rhs and then reverse it
andyhuang-kumo Aug 13, 2024
08041cb
removing none to replace with None
andyhuang-kumo Aug 13, 2024
29c70ad
add time fuse encoder to extract time pe
andyhuang-kumo Aug 14, 2024
f58b528
update hyperparameter options
andyhuang-kumo Aug 14, 2024
8f4966c
adding rerank_transformer
andyhuang-kumo Aug 19, 2024
90a8817
setting zeros for unused logits in rerank transformer
andyhuang-kumo Aug 19, 2024
1eb769c
adding reranker transformer
andyhuang-kumo Aug 21, 2024
ea58dd0
Merge branch 'master' into rhs_tr
andyhuang-kumo Aug 21, 2024
75445ce
updating RHS transformer code
andyhuang-kumo Aug 21, 2024
7380a17
updating rerank_transformer
andyhuang-kumo Aug 23, 2024
7622b94
push current version
andyhuang-kumo Aug 23, 2024
dc73f7c
converging to the following implementation
andyhuang-kumo Aug 23, 2024
38f9cf4
training both from epoch 0, potentially try better ways
andyhuang-kumo Aug 23, 2024
8e846dd
for transformer, not training with nodes whose prediction is not within
andyhuang-kumo Aug 25, 2024
41a38cf
committing before clean up
andyhuang-kumo Aug 26, 2024
3e71188
clean up, removing reference to rerank transformer
andyhuang-kumo Aug 26, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -19,3 +19,4 @@ coverage.xml
venv/*
*.out
data/**
+*.txt
9 changes: 6 additions & 3 deletions README.md
@@ -12,15 +12,16 @@
Run [`benchmark/relbench_link_prediction_benchmark.py`](https://github.com/kumo-ai/hybridgnn/blob/master/benchmark/relbench_link_prediction_benchmark.py)

```sh
-python relbench_link_prediction_benchmark.py --dataset rel-trial --task site-sponsor-run --model hybridgnn
+python relbench_link_prediction_benchmark.py --dataset rel-hm --task user-item-purchase --model rhstransformer
+python relbench_link_prediction_benchmark.py --dataset rel-trial --task site-sponsor-run --model hybridgnn --num_layers 4
```


Run [`examples/relbench_example.py`](https://github.com/kumo-ai/hybridgnn/blob/master/examples/relbench_example.py)

```sh
-python relbench_example.py --dataset rel-trial --task site-sponsor-run --model hybridgnn
-python relbench_example.py --dataset rel-trial --task condition-sponsor-run --model hybridgnn
+python relbench_example.py --dataset rel-trial --task site-sponsor-run --model hybridgnn --num_layers 4
+python relbench_example.py --dataset rel-trial --task condition-sponsor-run --model hybridgnn --num_layers 4
```


@@ -31,4 +32,6 @@ pip install -e .

# to run examples and benchmarks
pip install -e '.[full]'

+pip install -U sentence-transformers
```
38 changes: 31 additions & 7 deletions benchmark/relbench_link_prediction_benchmark.py
@@ -30,7 +30,7 @@
from torch_geometric.utils.cross_entropy import sparse_cross_entropy
from tqdm import tqdm

-from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN
+from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN, Hybrid_RHSTransformer
from hybridgnn.utils import GloveTextEmbedding

TRAIN_CONFIG_KEYS = ["batch_size", "gamma_rate", "base_lr"]
@@ -43,7 +43,7 @@
"--model",
type=str,
default="hybridgnn",
-    choices=["hybridgnn", "idgnn", "shallowrhsgnn"],
+    choices=["hybridgnn", "idgnn", "shallowrhsgnn", "rhstransformer"],
)
parser.add_argument("--epochs", type=int, default=20)
parser.add_argument("--num_trials", type=int, default=10,
@@ -102,7 +102,7 @@
int(args.num_neighbors // 2**i) for i in range(args.num_layers)
]

-model_cls: Type[Union[IDGNN, HybridGNN, ShallowRHSGNN]]
+model_cls: Type[Union[IDGNN, HybridGNN, ShallowRHSGNN, Hybrid_RHSTransformer]]

if args.model == "idgnn":
model_search_space = {
@@ -117,16 +117,30 @@
model_cls = IDGNN
elif args.model in ["hybridgnn", "shallowrhsgnn"]:
model_search_space = {
"channels": [64, 128, 256],
"embedding_dim": [64, 128, 256],
"channels": [64, 128],
"embedding_dim": [64, 128],
"norm": ["layer_norm", "batch_norm"]
}
train_search_space = {
"batch_size": [256, 512, 1024],
"batch_size": [256],
"base_lr": [0.001, 0.01],
"gamma_rate": [0.9, 0.95, 1.],
}
model_cls = (HybridGNN if args.model == "hybridgnn" else ShallowRHSGNN)
+elif args.model in ["rhstransformer"]:
+    model_search_space = {
+        "channels": [64],
+        "embedding_dim": [64],
+        "norm": ["layer_norm"],
+        "dropout": [0.1, 0.2],
+        "pe": ["abs", "none"],
+    }
+    train_search_space = {
+        "batch_size": [64],
+        "base_lr": [0.001, 0.01, 0.0001],
+        "gamma_rate": [0.9, 1.0],
+    }
+    model_cls = Hybrid_RHSTransformer


def train(
@@ -169,6 +183,11 @@ def train(
edge_label_index = torch.stack([src_batch, dst_index], dim=0)
loss = sparse_cross_entropy(logits, edge_label_index)
numel = len(batch[task.dst_entity_table].batch)
+        elif args.model in ["rhstransformer"]:
+            logits = model(batch, task.src_entity_table,
+                           task.dst_entity_table, task.dst_entity_col)
+            edge_label_index = torch.stack([src_batch, dst_index], dim=0)
+            loss = sparse_cross_entropy(logits, edge_label_index)
+            numel = len(batch[task.dst_entity_table].batch)
loss.backward()

optimizer.step()
@@ -210,6 +229,11 @@ def test(model: torch.nn.Module, loader: NeighborLoader, stage: str) -> float:
out = model(batch, task.src_entity_table,
task.dst_entity_table).detach()
scores = torch.sigmoid(out)
+        elif args.model in ["rhstransformer"]:
+            out = model(batch, task.src_entity_table,
+                        task.dst_entity_table,
+                        task.dst_entity_col).detach()
+            scores = torch.sigmoid(out)
else:
raise ValueError(f"Unsupported model type: {args.model}.")

@@ -248,7 +272,7 @@ def train_and_eval_with_cfg(
persistent_workers=args.num_workers > 0,
)

-    if args.model in ["hybridgnn", "shallowrhsgnn"]:
+    if args.model in ["hybridgnn", "shallowrhsgnn", "rhstransformer"]:
model_cfg["num_nodes"] = num_dst_nodes_dict["train"]
elif args.model == "idgnn":
model_cfg["out_channels"] = 1
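The `rhstransformer` branches reuse the sparse target format already used for `hybridgnn`: `edge_label_index` stacks (seed, destination) index pairs, and `sparse_cross_entropy` treats each row of `logits` as a distribution over the rhs candidates. A minimal, self-contained illustration of that call, with toy tensors rather than the benchmark's data:

```python
import torch
from torch_geometric.utils.cross_entropy import sparse_cross_entropy

# Toy setup: 2 seed (lhs) nodes scored against 4 rhs candidates.
logits = torch.randn(2, 4, requires_grad=True)

# Ground-truth links in sparse form, mirroring
# torch.stack([src_batch, dst_index], dim=0):
# seed 0 links to rhs 1 and 3; seed 1 links to rhs 0.
edge_label_index = torch.tensor([[0, 0, 1],
                                 [1, 3, 0]])

loss = sparse_cross_entropy(logits, edge_label_index)
loss.backward()  # gradients flow back into the logits as usual
```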
6 changes: 5 additions & 1 deletion hybridgnn/nn/models/__init__.py
@@ -2,5 +2,9 @@
from .idgnn import IDGNN
from .hybridgnn import HybridGNN
from .shallowrhsgnn import ShallowRHSGNN
+from .hybrid_rhstransformer import Hybrid_RHSTransformer

-__all__ = classes = ['HeteroGraphSAGE', 'IDGNN', 'HybridGNN', 'ShallowRHSGNN']
+__all__ = classes = [
+    'HeteroGraphSAGE', 'IDGNN', 'HybridGNN', 'ShallowRHSGNN',
+    'Hybrid_RHSTransformer'
+]
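With this export in place, the new model can be imported alongside the existing ones:

```python
from hybridgnn.nn.models import HybridGNN, Hybrid_RHSTransformer
```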
230 changes: 230 additions & 0 deletions hybridgnn/nn/models/hybrid_rhstransformer.py
@@ -0,0 +1,230 @@
from typing import Any, Dict

import torch
from torch import Tensor
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.nn import MLP
from torch_geometric.typing import NodeType

from hybridgnn.nn.encoder import (
DEFAULT_STYPE_ENCODER_DICT,
HeteroEncoder,
HeteroTemporalEncoder,
)
from hybridgnn.nn.models import HeteroGraphSAGE
from hybridgnn.nn.models.transformer import RHSTransformer
from torch_scatter import scatter_max


class Hybrid_RHSTransformer(torch.nn.Module):
r"""Implementation of RHSTransformer model.
Args:
data (HeteroData): dataset
col_stats_dict (Dict[str, Dict[str, Dict[StatType, Any]]]): column stats
num_nodes (int): number of nodes,
num_layers (int): number of mp layers,
channels (int): input dimension,
embedding_dim (int): embedding dimension size,
aggr (str): aggregation type,
norm (norm): normalization type,
dropout (float): dropout rate for the transformer float,
heads (int): number of attention heads,
pe (str): type of positional encoding for the transformer,"""
def __init__(
self,
data: HeteroData,
col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],
num_nodes: int,
num_layers: int,
channels: int,
embedding_dim: int,
aggr: str = 'sum',
norm: str = 'layer_norm',
dropout: float = 0.2,
heads: int = 1,
pe: str = "abs",
) -> None:
super().__init__()

self.encoder = HeteroEncoder(
channels=channels,
node_to_col_names_dict={
node_type: data[node_type].tf.col_names_dict
for node_type in data.node_types
},
node_to_col_stats=col_stats_dict,
stype_encoder_cls_kwargs=DEFAULT_STYPE_ENCODER_DICT,
)
self.temporal_encoder = HeteroTemporalEncoder(
node_types=[
node_type for node_type in data.node_types
if "time" in data[node_type]
],
channels=channels,
)
self.gnn = HeteroGraphSAGE(
node_types=data.node_types,
edge_types=data.edge_types,
channels=channels,
aggr=aggr,
num_layers=num_layers,
)
self.head = MLP(
channels,
out_channels=1,
norm=norm,
num_layers=1,
)
self.lhs_projector = torch.nn.Linear(channels, embedding_dim)
self.id_awareness_emb = torch.nn.Embedding(1, channels)
self.rhs_embedding = torch.nn.Embedding(num_nodes, embedding_dim)
self.lin_offset_idgnn = torch.nn.Linear(embedding_dim, 1)
self.lin_offset_embgnn = torch.nn.Linear(embedding_dim, 1)

self.rhs_transformer = RHSTransformer(in_channels=channels,
out_channels=channels,
hidden_channels=channels,
heads=heads, dropout=dropout,
position_encoding=pe)

self.channels = channels

self.reset_parameters()

def reset_parameters(self) -> None:
self.encoder.reset_parameters()
self.temporal_encoder.reset_parameters()
self.gnn.reset_parameters()
self.head.reset_parameters()
self.id_awareness_emb.reset_parameters()
self.rhs_embedding.reset_parameters()
self.lin_offset_embgnn.reset_parameters()
self.lin_offset_idgnn.reset_parameters()
self.lhs_projector.reset_parameters()
self.rhs_transformer.reset_parameters()

def forward(
self,
batch: HeteroData,
entity_table: NodeType,
dst_table: NodeType,
dst_entity_col: NodeType,
) -> Tensor:
seed_time = batch[entity_table].seed_time
x_dict = self.encoder(batch.tf_dict)

# Add ID-awareness to the root node
        x_dict[entity_table][:seed_time.size(0)] += (
            self.id_awareness_emb.weight)
rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict,
batch.batch_dict)

for node_type, rel_time in rel_time_dict.items():
x_dict[node_type] = x_dict[node_type] + rel_time

x_dict = self.gnn(
x_dict,
batch.edge_index_dict,
)

batch_size = seed_time.size(0)
        lhs_embedding = x_dict[entity_table][:batch_size]  # batch_size, channel
lhs_embedding_projected = self.lhs_projector(lhs_embedding)
rhs_gnn_embedding = x_dict[dst_table] # num_sampled_rhs, channel
rhs_idgnn_index = batch.n_id_dict[dst_table] # num_sampled_rhs
        lhs_idgnn_batch = batch.batch_dict[dst_table]  # num_sampled_rhs

        # NOTE: needs custom, dataset-specific code to extract rhs timestamps:
        # rhs_time = self.get_rhs_time_dict(
        #     batch.time_dict, batch.edge_index_dict,
        #     batch[entity_table].seed_time, batch, dst_entity_col, dst_table)

        # Refine the sampled rhs embeddings with the RHS transformer.
        rhs_gnn_embedding = self.rhs_transformer(rhs_gnn_embedding,
                                                 lhs_idgnn_batch,
                                                 batch_size=batch_size)

rhs_embedding = self.rhs_embedding # num_rhs_nodes, channel
        # batch_size, num_rhs_nodes
        embgnn_logits = lhs_embedding_projected @ rhs_embedding.weight.t()

# Model the importance of embedding-GNN prediction for each lhs node
embgnn_offset_logits = self.lin_offset_embgnn(
lhs_embedding_projected).flatten()
embgnn_logits += embgnn_offset_logits.view(-1, 1)

# Calculate idgnn logits
        idgnn_logits = self.head(rhs_gnn_embedding).flatten()  # num_sampled_rhs
        # Because we only sample 2 hops, information from the lhs is not
        # really propagated; we therefore incorporate it explicitly via
        # lhs_embedding[lhs_idgnn_batch] * rhs_gnn_embedding.
        idgnn_logits += (
            lhs_embedding[lhs_idgnn_batch] *  # num_sampled_rhs, channel
            rhs_gnn_embedding).sum(dim=-1).flatten()  # num_sampled_rhs

# Model the importance of ID-GNN prediction for each lhs node
idgnn_offset_logits = self.lin_offset_idgnn(
lhs_embedding_projected).flatten()
idgnn_logits = idgnn_logits + idgnn_offset_logits[lhs_idgnn_batch]

embgnn_logits[lhs_idgnn_batch, rhs_idgnn_index] = idgnn_logits
return embgnn_logits

def get_rhs_time_dict(
self,
time_dict,
edge_index_dict,
seed_time,
batch_dict,
dst_entity_col,
dst_entity_table,
):
        # For reference, edge_index_dict has the following keys (rel-trial):
"""
dict_keys([('drop_withdrawals', 'f2p_nct_id', 'studies'),
('studies', 'rev_f2p_nct_id', 'drop_withdrawals'),
('outcomes', 'f2p_nct_id', 'studies'),
('studies', 'rev_f2p_nct_id', 'outcomes'),
('outcome_analyses', 'f2p_nct_id', 'studies'),
('studies', 'rev_f2p_nct_id', 'outcome_analyses'),
('outcome_analyses', 'f2p_outcome_id', 'outcomes'),
('outcomes', 'rev_f2p_outcome_id', 'outcome_analyses'),
('eligibilities', 'f2p_nct_id', 'studies'),
('studies', 'rev_f2p_nct_id', 'eligibilities'),
('sponsors_studies', 'f2p_nct_id', 'studies'),
('studies', 'rev_f2p_nct_id', 'sponsors_studies'),
('sponsors_studies', 'f2p_sponsor_id', 'sponsors'),
('sponsors', 'rev_f2p_sponsor_id', 'sponsors_studies'),
('facilities_studies', 'f2p_nct_id', 'studies'),
('studies', 'rev_f2p_nct_id', 'facilities_studies'),
('facilities_studies', 'f2p_facility_id', 'facilities'),
('facilities', 'rev_f2p_facility_id', 'facilities_studies'),
('interventions_studies', 'f2p_nct_id', 'studies'),
('studies', 'rev_f2p_nct_id', 'interventions_studies'),
('interventions_studies', 'f2p_intervention_id', 'interventions'),
('interventions', 'rev_f2p_intervention_id', 'interventions_studies'),
('designs', 'f2p_nct_id', 'studies'),
('studies', 'rev_f2p_nct_id', 'designs'),
('reported_event_totals', 'f2p_nct_id', 'studies'),
('studies', 'rev_f2p_nct_id', 'reported_event_totals'),
('conditions_studies', 'f2p_nct_id', 'studies'),
('studies', 'rev_f2p_nct_id', 'conditions_studies'),
('conditions_studies', 'f2p_condition_id', 'conditions'),
('conditions', 'rev_f2p_condition_id', 'conditions_studies')])
"""
        # TODO: decide which edge type to use once the transactions table is
        # merged. NOTE: the key ('sponsors', 'f2p_sponsor_id',
        # 'sponsors_studies') used previously does not appear in the key list
        # above; the f2p edge runs from sponsors_studies to sponsors, and
        # only sponsors_studies carries timestamps in time_dict.
        edge_index = edge_index_dict['sponsors_studies', 'f2p_sponsor_id',
                                     'sponsors']
        rhs_time, _ = scatter_max(
            time_dict['sponsors_studies'][edge_index[0]], edge_index[1])
        NANOSECONDS_IN_A_DAY = 60 * 60 * 24 * 1_000_000_000
rhs_rel_time = seed_time[batch_dict[dst_entity_col]] - rhs_time
rhs_rel_time = rhs_rel_time / NANOSECONDS_IN_A_DAY
return rhs_rel_time
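The `RHSTransformer` module used above lives in `hybridgnn/nn/models/transformer.py` and is not part of this diff. Judging only from the call site (a flat rhs embedding tensor, the `lhs_idgnn_batch` assignment vector, and a `batch_size`), a minimal sketch of the mechanism might look like the following; this is an illustration under those assumptions, not the repository's implementation:

```python
import torch
from torch import Tensor
from torch_geometric.utils import to_dense_batch


class RHSTransformerSketch(torch.nn.Module):
    """Self-attention over the sampled rhs candidates of each seed node."""
    def __init__(self, channels: int, heads: int = 1,
                 dropout: float = 0.2) -> None:
        super().__init__()
        self.attn = torch.nn.MultiheadAttention(channels, heads,
                                                dropout=dropout,
                                                batch_first=True)
        self.norm = torch.nn.LayerNorm(channels)

    def forward(self, rhs_embed: Tensor, lhs_batch: Tensor,
                batch_size: int) -> Tensor:
        # Group the flat [num_sampled_rhs, channels] embeddings into a dense
        # [batch_size, max_rhs_per_seed, channels] tensor, one row per seed.
        # to_dense_batch assumes lhs_batch is sorted and every seed has at
        # least one candidate.
        x, mask = to_dense_batch(rhs_embed, lhs_batch, batch_size=batch_size)
        # Candidates of the same seed attend to each other; key_padding_mask
        # is True at padded positions that should be ignored.
        out, _ = self.attn(x, x, x, key_padding_mask=~mask)
        out = self.norm(out + x)  # residual connection + layer norm
        return out[mask]  # back to the flat [num_sampled_rhs, channels] layout
```

The commit "permute the index, the rhs and then reverse it" suggests the real module sorts candidates by seed before densifying and restores the original order afterwards, which matches the sorted-batch assumption above.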