Skip to content

Commit

Permalink
adding transformer changes
Browse files Browse the repository at this point in the history
  • Loading branch information
andyhuang-kumo committed Aug 13, 2024
1 parent 5e904ea commit d6ff169
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Run [`benchmark/relbench_link_prediction_benchmark.py`](https://github.com/kumo-ai/hybridgnn/blob/master/benchmark/relbench_link_prediction_benchmark.py)

```sh
python relbench_link_prediction_benchmark.py --dataset rel-stack --task user-post-comment --model rhstransformer --num_trials 10
python relbench_link_prediction_benchmark.py --dataset rel-hm --task user-item-purcahse --model rhstransformer
python relbench_link_prediction_benchmark.py --dataset rel-trial --task site-sponsor-run --model hybridgnn --num_layers 4
```
Expand Down
5 changes: 3 additions & 2 deletions hybridgnn/nn/models/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,9 @@ def forward(self, rhs_embed: Tensor, index: Tensor, batch_size=512) -> Tensor:
rhs_embed = rhs_embed + self.pe(
torch.arange(rhs_embed.size(0), device=rhs_embed.device))

sorted_index, _ = torch.sort(index)
index = sorted_index
# #! if we sort the index, we need to sort the rhs_embed
# sorted_index, _ = torch.sort(index)
# assert torch.equal(index, sorted_index)

x, mask = to_dense_batch(rhs_embed, index, batch_size=batch_size)
for block in self.blocks:
Expand Down

0 comments on commit d6ff169

Please sign in to comment.