diff --git a/layers.py b/layers.py
index 21bc0e6..3b0c0d3 100644
--- a/layers.py
+++ b/layers.py
@@ -8,7 +8,6 @@ class GraphAttentionLayer(nn.Module):
     """
     Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
     """
-
     def __init__(self, in_features, out_features, dropout, alpha, concat=True):
         super(GraphAttentionLayer, self).__init__()
         self.dropout = dropout
@@ -17,31 +16,70 @@ def __init__(self, in_features, out_features, dropout, alpha, concat=True):
         self.alpha = alpha
         self.concat = concat

-        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
+        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
         nn.init.xavier_uniform_(self.W.data, gain=1.414)
-        self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1)))
+        self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))
         nn.init.xavier_uniform_(self.a.data, gain=1.414)

         self.leakyrelu = nn.LeakyReLU(self.alpha)

-    def forward(self, input, adj):
-        h = torch.mm(input, self.W)
-        N = h.size()[0]
-
-        a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features)
+    def forward(self, h, adj):
+        Wh = torch.mm(h, self.W) # h.shape: (N, in_features), Wh.shape: (N, out_features)
+        a_input = self._prepare_attentional_mechanism_input(Wh)
         e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))

         zero_vec = -9e15*torch.ones_like(e)
         attention = torch.where(adj > 0, e, zero_vec)
         attention = F.softmax(attention, dim=1)
         attention = F.dropout(attention, self.dropout, training=self.training)
-        h_prime = torch.matmul(attention, h)
+        h_prime = torch.matmul(attention, Wh)

         if self.concat:
             return F.elu(h_prime)
         else:
             return h_prime

+    def _prepare_attentional_mechanism_input(self, Wh):
+        N = Wh.size()[0] # number of nodes
+
+        # Below, two matrices are created that contain embeddings in their rows in different orders.
+        # (e stands for embedding)
+        # These are the rows of the first matrix (Wh_repeated_in_chunks):
+        # e1, e1, ..., e1,            e2, e2, ..., e2,            ..., eN, eN, ..., eN
+        # '-------------' -> N times  '-------------' -> N times       '-------------' -> N times
+        #
+        # These are the rows of the second matrix (Wh_repeated_alternating):
+        # e1, e2, ..., eN, e1, e2, ..., eN, ..., e1, e2, ..., eN
+        # '----------------------------------------------------' -> N times
+
+        Wh_repeated_in_chunks = Wh.repeat(1, N).view(N * N, self.out_features)
+        Wh_repeated_alternating = Wh.repeat(N, 1)
+        # Wh_repeated_in_chunks.shape == Wh_repeated_alternating.shape == (N * N, out_features)
+
+        # The all_combinations_matrix, created below, will look like this (|| denotes concatenation):
+        # e1 || e1
+        # e1 || e2
+        # e1 || e3
+        # ...
+        # e1 || eN
+        # e2 || e1
+        # e2 || e2
+        # e2 || e3
+        # ...
+        # e2 || eN
+        # ...
+        # eN || e1
+        # eN || e2
+        # eN || e3
+        # ...
+        # eN || eN
+
+        all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=1)
+        # all_combinations_matrix.shape == (N * N, 2 * out_features)
+
+        return all_combinations_matrix.view(N, N, 2 * self.out_features)
+
     def __repr__(self):
         return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'
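
Note: the snippet below is a minimal, illustrative check, not part of the patch. It assumes the patched GraphAttentionLayer is importable from layers.py with PyTorch installed; the node count, feature sizes, and dropout/alpha values are arbitrary. It verifies that the new _prepare_attentional_mechanism_input helper builds the same (N, N, 2*out_features) tensor as the inline repeat/cat expression it replaces, so the forward pass is unchanged apart from the torch.zeros -> torch.empty initialization (both are overwritten by the Xavier init anyway).

    import torch
    from layers import GraphAttentionLayer  # assumes the patched layers.py is on the path

    torch.manual_seed(0)
    N, in_features, out_features = 5, 8, 4  # arbitrary sizes for the check
    layer = GraphAttentionLayer(in_features, out_features, dropout=0.6, alpha=0.2)

    h = torch.rand(N, in_features)   # random node features
    Wh = torch.mm(h, layer.W)        # projected embeddings, shape (N, out_features)

    # New helper introduced by this diff.
    a_input_new = layer._prepare_attentional_mechanism_input(Wh)

    # Old inline construction that the helper replaces.
    a_input_old = torch.cat([Wh.repeat(1, N).view(N * N, -1), Wh.repeat(N, 1)],
                            dim=1).view(N, -1, 2 * out_features)

    assert a_input_new.shape == (N, N, 2 * out_features)
    assert torch.equal(a_input_new, a_input_old)  # identical tensors; only the code moved

    # The forward pass still runs as before (fully connected toy adjacency).
    out = layer(h, torch.ones(N, N))
    assert out.shape == (N, out_features)

The refactor is purely structural: the pairwise concatenation logic moves into a named helper whose comments document the row ordering, while the attention weights, masking, softmax, and aggregation are untouched.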