Merge pull request PetarV-#41 from ahmad-PH/refactor_GraphAttentionLayer_forward

Refactor creation of a_input into separate function
Diego999 authored Sep 26, 2020
2 parents e6a8fa5 + 901a9ee commit 7fea210
Showing 1 changed file with 47 additions and 9 deletions.
56 changes: 47 additions & 9 deletions layers.py
@@ -8,7 +8,6 @@ class GraphAttentionLayer(nn.Module):
"""
Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
"""

def __init__(self, in_features, out_features, dropout, alpha, concat=True):
super(GraphAttentionLayer, self).__init__()
self.dropout = dropout
@@ -17,31 +16,70 @@ def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        self.alpha = alpha
        self.concat = concat

-        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
+        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
-        self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1)))
+        self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

-    def forward(self, input, adj):
-        h = torch.mm(input, self.W)
-        N = h.size()[0]
-
-        a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features)
+    def forward(self, h, adj):
+        Wh = torch.mm(h, self.W) # h.shape: (N, in_features), Wh.shape: (N, out_features)
+        a_input = self._prepare_attentional_mechanism_input(Wh)
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))

        zero_vec = -9e15*torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)
-        h_prime = torch.matmul(attention, h)
+        h_prime = torch.matmul(attention, Wh)

        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

+    def _prepare_attentional_mechanism_input(self, Wh):
+        N = Wh.size()[0] # number of nodes
+
+        # Below, two matrices are created that contain embeddings in their rows in different orders.
+        # (e stands for embedding)
+        # These are the rows of the first matrix (Wh_repeated_in_chunks):
+        # e1, e1, ..., e1, e2, e2, ..., e2, ..., eN, eN, ..., eN
+        # '-------------' -> N times  '-------------' -> N times  '-------------' -> N times
+        #
+        # These are the rows of the second matrix (Wh_repeated_alternating):
+        # e1, e2, ..., eN, e1, e2, ..., eN, ..., e1, e2, ..., eN
+        # '----------------------------------------------------' -> N times
+        #
+
+        Wh_repeated_in_chunks = Wh.repeat(1, N).view(N * N, self.out_features)
+        Wh_repeated_alternating = Wh.repeat(N, 1)
+        # Wh_repeated_in_chunks.shape == Wh_repeated_alternating.shape == (N * N, out_features)
+
+        # The all_combinations_matrix, created below, will look like this (|| denotes concatenation):
+        # e1 || e1
+        # e1 || e2
+        # e1 || e3
+        # ...
+        # e1 || eN
+        # e2 || e1
+        # e2 || e2
+        # e2 || e3
+        # ...
+        # e2 || eN
+        # ...
+        # eN || e1
+        # eN || e2
+        # eN || e3
+        # ...
+        # eN || eN
+
+        all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=1)
+        # all_combinations_matrix.shape == (N * N, 2 * out_features)
+
+        return all_combinations_matrix.view(N, N, 2 * self.out_features)

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'
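
For readers tracing the new helper, here is a minimal standalone sketch of the same pairing construction; the node count, feature size, and tensor values are illustrative only and are not part of the commit:

import torch

N, out_features = 3, 2                      # illustrative sizes only
Wh = torch.arange(N * out_features, dtype=torch.float32).view(N, out_features)

# Rows repeated in chunks: e1, e1, e1, e2, e2, e2, e3, e3, e3
Wh_repeated_in_chunks = Wh.repeat(1, N).view(N * N, out_features)
# Rows repeated alternating: e1, e2, e3, e1, e2, e3, e1, e2, e3
Wh_repeated_alternating = Wh.repeat(N, 1)

# Concatenating the two row-wise yields every ordered pair e_i || e_j,
# so the final reshape gives a_input[i, j] == [e_i || e_j].
all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=1)
a_input = all_combinations_matrix.view(N, N, 2 * out_features)

print(a_input.shape)   # torch.Size([3, 3, 4])
print(a_input[1, 2])   # tensor([2., 3., 4., 5.]) -- node 1's embedding followed by node 2's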

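And a minimal usage sketch of the refactored layer, assuming the imports already present at the top of layers.py (torch, torch.nn, torch.nn.functional); the sizes, hyperparameters, and random graph below are purely illustrative:

import torch
from layers import GraphAttentionLayer

torch.manual_seed(0)
N, in_features, out_features = 5, 8, 4      # illustrative sizes only

layer = GraphAttentionLayer(in_features, out_features, dropout=0.6, alpha=0.2, concat=True)

h = torch.rand(N, in_features)              # node features, shape (N, in_features)
adj = (torch.rand(N, N) > 0.5).float()      # dense 0/1 adjacency; zero entries are masked with -9e15 before the softmax
adj.fill_diagonal_(1)                       # keep self-loops so every node attends to at least itself

out = layer(h, adj)                         # forward pass: attention-weighted, ELU-activated node embeddings
print(out.shape)                            # torch.Size([5, 4])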
