commit 829fee5
adding initial implementation of rhs transformer to repo
1 parent 1a336eb
Showing 6 changed files with 292 additions and 6 deletions.
@@ -19,3 +19,4 @@ coverage.xml
 venv/*
 *.out
 data/**
+*.txt
@@ -0,0 +1,153 @@
from typing import Any, Dict

import torch
from torch import Tensor
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.nn import MLP
from torch_geometric.typing import NodeType

from hybridgnn.nn.encoder import (
    DEFAULT_STYPE_ENCODER_DICT,
    HeteroEncoder,
    HeteroTemporalEncoder,
)
from hybridgnn.nn.models import HeteroGraphSAGE
from hybridgnn.nn.models.transformer import RHSTransformer


class Hybrid_RHSTransformer(torch.nn.Module):
    r"""Hybrid ID-GNN / embedding-GNN model that refines the sampled RHS
    embeddings with an :class:`RHSTransformer` before computing logits."""
    def __init__(
        self,
        data: HeteroData,
        col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],
        num_nodes: int,
        num_layers: int,
        channels: int,
        embedding_dim: int,
        aggr: str = 'sum',
        norm: str = 'layer_norm',
        pe: str = "abs",
    ) -> None:
        super().__init__()

        self.encoder = HeteroEncoder(
            channels=channels,
            node_to_col_names_dict={
                node_type: data[node_type].tf.col_names_dict
                for node_type in data.node_types
            },
            node_to_col_stats=col_stats_dict,
            stype_encoder_cls_kwargs=DEFAULT_STYPE_ENCODER_DICT,
        )
        self.temporal_encoder = HeteroTemporalEncoder(
            node_types=[
                node_type for node_type in data.node_types
                if "time" in data[node_type]
            ],
            channels=channels,
        )
        self.gnn = HeteroGraphSAGE(
            node_types=data.node_types,
            edge_types=data.edge_types,
            channels=channels,
            aggr=aggr,
            num_layers=num_layers,
        )
        self.head = MLP(
            channels,
            out_channels=1,
            norm=norm,
            num_layers=1,
        )
        self.lhs_projector = torch.nn.Linear(channels, embedding_dim)

        self.id_awareness_emb = torch.nn.Embedding(1, channels)
        self.rhs_embedding = torch.nn.Embedding(num_nodes, embedding_dim)
        self.lin_offset_idgnn = torch.nn.Linear(embedding_dim, 1)
        self.lin_offset_embgnn = torch.nn.Linear(embedding_dim, 1)
        self.rhs_transformer = RHSTransformer(
            in_channels=channels,
            out_channels=channels,
            hidden_channels=channels,
            heads=1,
            dropout=0.2,
            position_encoding=pe,
        )

        self.channels = channels

        self.reset_parameters()

    def reset_parameters(self) -> None:
        self.encoder.reset_parameters()
        self.temporal_encoder.reset_parameters()
        self.gnn.reset_parameters()
        self.head.reset_parameters()
        self.id_awareness_emb.reset_parameters()
        self.rhs_embedding.reset_parameters()
        self.lin_offset_embgnn.reset_parameters()
        self.lin_offset_idgnn.reset_parameters()
        self.lhs_projector.reset_parameters()
        self.rhs_transformer.reset_parameters()

    def forward(
        self,
        batch: HeteroData,
        entity_table: NodeType,
        dst_table: NodeType,
    ) -> Tensor:
        seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict)

        # Add ID-awareness to the root (seed) nodes.
        x_dict[entity_table][:seed_time.size(0)] += self.id_awareness_emb.weight
        rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict,
                                              batch.batch_dict)

        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
        )

        batch_size = seed_time.size(0)
        lhs_embedding = x_dict[entity_table][:batch_size]  # [batch_size, channels]
        lhs_embedding_projected = self.lhs_projector(lhs_embedding)
        rhs_gnn_embedding = x_dict[dst_table]  # [num_sampled_rhs, channels]
        rhs_idgnn_index = batch.n_id_dict[dst_table]  # [num_sampled_rhs]
        lhs_idgnn_batch = batch.batch_dict[dst_table]  # [num_sampled_rhs]

        # Let the rhs nodes sampled for each seed attend to each other.
        rhs_gnn_embedding = self.rhs_transformer(rhs_gnn_embedding,
                                                 lhs_idgnn_batch)

        rhs_embedding = self.rhs_embedding  # [num_rhs_nodes, embedding_dim]
        embgnn_logits = lhs_embedding_projected @ rhs_embedding.weight.t()  # [batch_size, num_rhs_nodes]

        # Model the importance of the embedding-GNN prediction for each lhs node.
        embgnn_offset_logits = self.lin_offset_embgnn(
            lhs_embedding_projected).flatten()
        embgnn_logits += embgnn_offset_logits.view(-1, 1)

        # Calculate ID-GNN logits.
        idgnn_logits = self.head(rhs_gnn_embedding).flatten()  # [num_sampled_rhs]
        # Because we only sample 2 hops, information from the lhs is not really
        # propagated to the rhs nodes, so we incorporate it explicitly via
        # lhs_embedding[lhs_idgnn_batch] * rhs_gnn_embedding.
        idgnn_logits += (
            lhs_embedding[lhs_idgnn_batch] *  # [num_sampled_rhs, channels]
            rhs_gnn_embedding).sum(dim=-1).flatten()

        # Model the importance of the ID-GNN prediction for each lhs node.
        idgnn_offset_logits = self.lin_offset_idgnn(
            lhs_embedding_projected).flatten()
        idgnn_logits = idgnn_logits + idgnn_offset_logits[lhs_idgnn_batch]

        # Overwrite the dense scores at the sampled (lhs, rhs) pairs.
        embgnn_logits[lhs_idgnn_batch, rhs_idgnn_index] = idgnn_logits
        return embgnn_logits
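
For reference, a minimal, self-contained sketch of the final scatter step above, using toy tensor sizes that are not from this commit: the dense embedding-GNN score matrix is overwritten at the sampled (lhs, rhs) pairs by the ID-GNN scores, so sampled pairs are ranked by the ID-GNN while all other pairs fall back to the embedding-GNN.

import torch

batch_size, num_rhs_nodes = 2, 5                 # toy sizes (assumption)
embgnn_logits = torch.zeros(batch_size, num_rhs_nodes)
idgnn_logits = torch.tensor([1.0, 2.0, 3.0])     # one score per sampled rhs node
lhs_idgnn_batch = torch.tensor([0, 0, 1])        # seed (lhs) node of each sampled rhs
rhs_idgnn_index = torch.tensor([4, 1, 2])        # global rhs id of each sampled rhs

embgnn_logits[lhs_idgnn_batch, rhs_idgnn_index] = idgnn_logits
# Row 0 now holds 2.0 at column 1 and 1.0 at column 4; row 1 holds 3.0 at column 2.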
@@ -0,0 +1,113 @@
from typing import Optional

import torch
from torch import Tensor
from torch_geometric.nn.aggr.utils import MultiheadAttentionBlock
from torch_geometric.nn.encoding import PositionalEncoding
from torch_geometric.utils import to_dense_batch


class RHSTransformer(torch.nn.Module):
    r"""A module that attends over the sampled RHS embeddings of each seed
    node with a transformer.

    Args:
        in_channels (int): The number of input channels of the RHS embedding.
        out_channels (int): The number of output channels.
        hidden_channels (int): The hidden channel dimension of the transformer.
        heads (int): The number of attention heads for the transformer.
        num_transformer_blocks (int): The number of transformer blocks.
        dropout (float): The dropout rate of the transformer.
        position_encoding (str): The type of positional encoding, either
            :obj:`"abs"` (absolute) or :obj:`"rope"` (rotary).
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int = 64,
        heads: int = 1,
        num_transformer_blocks: int = 1,
        dropout: float = 0.0,
        position_encoding: str = "abs",
    ) -> None:
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.lin = torch.nn.Linear(in_channels, hidden_channels)
        self.fc = torch.nn.Linear(hidden_channels, out_channels)
        self.pe_type = position_encoding
        self.pe = None
        if position_encoding == "abs":
            self.pe = PositionalEncoding(hidden_channels)
        elif position_encoding == "rope":
            # Rotary positional encodings for queries and keys.
            self.q_pe = RotaryPositionalEmbeddings(hidden_channels)
            self.k_pe = RotaryPositionalEmbeddings(hidden_channels)

        self.blocks = torch.nn.ModuleList([
            MultiheadAttentionBlock(
                channels=hidden_channels,
                heads=heads,
                layer_norm=True,
                dropout=dropout,
            ) for _ in range(num_transformer_blocks)
        ])

    def reset_parameters(self) -> None:
        for block in self.blocks:
            block.reset_parameters()
        self.lin.reset_parameters()
        self.fc.reset_parameters()

    def forward(self, rhs_embed: Tensor, index: Tensor,
                rhs_time: Optional[Tensor] = None) -> Tensor:
        r"""Attends over the RHS embeddings within each group given by
        :obj:`index` and returns the updated embeddings."""
        rhs_embed = self.lin(rhs_embed)

        if self.pe_type == "abs":
            if rhs_time is None:
                rhs_embed = rhs_embed + self.pe(
                    torch.arange(rhs_embed.size(0), device=rhs_embed.device))
            else:
                rhs_embed = rhs_embed + self.pe(rhs_time)

        # Group the RHS embeddings of each seed node into a dense batch.
        x, mask = to_dense_batch(rhs_embed, index)
        for block in self.blocks:
            if self.pe_type == "rope":
                # Apply the rotary encoding to both queries and keys.
                x_q = self.q_pe(x, pos=rhs_time)
                x_k = self.k_pe(x, pos=rhs_time)
            else:
                x_q = x
                x_k = x
            x = block(x_q, x_k)
        x = x[mask]
        x = x.view(-1, self.hidden_channels)
        return self.fc(x)


class RotaryPositionalEmbeddings(torch.nn.Module):
    r"""Rotary positional embeddings (RoPE) that rotate consecutive channel
    pairs by a position-dependent angle."""
    def __init__(self, channels: int, base: int = 10000) -> None:
        super().__init__()
        self.channels = channels
        self.base = base
        inv_freq = 1.0 / (base**(torch.arange(0, channels, 2).float() /
                                 channels))
        # Register as a buffer so that it moves to the module's device.
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x: Tensor, pos: Optional[Tensor] = None) -> Tensor:
        seq_len = x.shape[1]
        if pos is None:
            pos = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        freqs = torch.einsum('i,j->ij', pos, self.inv_freq)
        # Repeat each frequency twice so that channels (2i, 2i + 1) share the
        # same rotation angle.
        emb = freqs.repeat_interleave(2, dim=-1)

        cos = emb.cos().to(x.device)
        sin = emb.sin().to(x.device)

        # Rotate each channel pair: (x1, x2) -> (x1 cos - x2 sin, x2 cos + x1 sin).
        x1, x2 = x[..., ::2], x[..., 1::2]
        rotated = torch.stack([-x2, x1], dim=-1).reshape(x.shape).to(x.device)

        return x * cos + rotated * sin
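
A minimal usage sketch of the transformer above; the toy shapes are assumptions, and the import path mirrors the one used in the model file earlier in this commit.

import torch
from hybridgnn.nn.models.transformer import RHSTransformer

transformer = RHSTransformer(in_channels=64, out_channels=64,
                             hidden_channels=64, heads=1,
                             position_encoding="abs")
rhs_embed = torch.randn(10, 64)                       # 10 sampled rhs node embeddings
index = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 3])  # seed (lhs) node of each rhs node
out = transformer(rhs_embed, index)                   # -> shape [10, 64]

RHS nodes that share an index value attend only to one another, because to_dense_batch packs each group into its own row of the dense batch before the attention blocks are applied.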