From d6ff1698e29d26b298723a02640c989d5777b65a Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Tue, 13 Aug 2024 20:47:16 +0000 Subject: [PATCH] adding transformer changes --- README.md | 1 + hybridgnn/nn/models/transformer.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 622e1a1..d02d749 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ Run [`benchmark/relbench_link_prediction_benchmark.py`](https://github.com/kumo-ai/hybridgnn/blob/master/benchmark/relbench_link_prediction_benchmark.py) ```sh +python relbench_link_prediction_benchmark.py --dataset rel-stack --task user-post-comment --model rhstransformer --num_trials 10 python relbench_link_prediction_benchmark.py --dataset rel-hm --task user-item-purcahse --model rhstransformer python relbench_link_prediction_benchmark.py --dataset rel-trial --task site-sponsor-run --model hybridgnn --num_layers 4 ``` diff --git a/hybridgnn/nn/models/transformer.py b/hybridgnn/nn/models/transformer.py index 40a3e8c..502ee4e 100644 --- a/hybridgnn/nn/models/transformer.py +++ b/hybridgnn/nn/models/transformer.py @@ -72,8 +72,9 @@ def forward(self, rhs_embed: Tensor, index: Tensor, batch_size=512) -> Tensor: rhs_embed = rhs_embed + self.pe( torch.arange(rhs_embed.size(0), device=rhs_embed.device)) - sorted_index, _ = torch.sort(index) - index = sorted_index + # #! if we sort the index, we need to sort the rhs_embed + # sorted_index, _ = torch.sort(index) + # assert torch.equal(index, sorted_index) x, mask = to_dense_batch(rhs_embed, index, batch_size=batch_size) for block in self.blocks: