
Commit e43267c
update
xnuohz committed Dec 31, 2024
1 parent 83703b1
Showing 3 changed files with 6 additions and 9 deletions.
5 changes: 2 additions & 3 deletions test/nn/attention/test_performer.py
@@ -5,9 +5,8 @@
 
 def test_performer_attention():
     x = torch.randn(1, 4, 16)
-    mask = torch.ones([1, 4], dtype=torch.bool)
     attn = PerformerAttention(channels=16, heads=4)
-    out = attn(x, mask)
+    out = attn(x)
     assert out.shape == (1, 4, 16)
     assert str(attn) == ('PerformerAttention(heads=4, '
-                         'head_channels=64 kernel=ReLU())')
+                         'head_channels=64)')
7 changes: 3 additions & 4 deletions test/nn/attention/test_polynormer.py
@@ -3,11 +3,10 @@
 from torch_geometric.nn.attention import PolynormerAttention
 
 
-def test_performer_attention():
+def test_polynormer_attention():
     x = torch.randn(1, 4, 16)
-    mask = torch.ones([1, 4], dtype=torch.bool)
     attn = PolynormerAttention(channels=16, heads=4)
-    out = attn(x, mask)
+    out = attn(x)
     assert out.shape == (1, 4, 16)
     assert str(attn) == ('PolynormerAttention(heads=4, '
-                         'head_channels=64 kernel=ReLU())')
+                         'head_channels=64)')
3 changes: 1 addition & 2 deletions torch_geometric/nn/attention/polynormer.py
@@ -87,5 +87,4 @@ def reset_parameters(self):
     def __repr__(self) -> str:
         return (f'{self.__class__.__name__}('
                 f'heads={self.heads}, '
-                f'head_channels={self.head_channels} '
-                f'kernel={self.kernel})')
+                f'head_channels={self.head_channels})')
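
Taken together, the hunks drop the unused mask argument from both attention tests and remove the kernel field from PolynormerAttention's repr. A minimal usage sketch of the post-commit API, assuming both modules are importable from torch_geometric.nn.attention exactly as the tests above import them:

import torch

from torch_geometric.nn.attention import PerformerAttention, PolynormerAttention

# Both modules are now exercised without the boolean mask argument
# that the old tests passed in.
x = torch.randn(1, 4, 16)  # (batch, nodes, channels), as in the tests

performer = PerformerAttention(channels=16, heads=4)
polynormer = PolynormerAttention(channels=16, heads=4)

assert performer(x).shape == (1, 4, 16)
assert polynormer(x).shape == (1, 4, 16)

# head_channels defaults to 64 per the reprs asserted in the tests;
# after this commit the kernel is no longer reported.
print(polynormer)  # PolynormerAttention(heads=4, head_channels=64)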
