I think the forward pass of the TransformerEncoder should use a padding mask for the attention: the padding tokens need to be excluded when calculating the attention weights. This is related to Chapter 12.2.1. See cell 33 here, and see the PyTorch docs for reference.
It should be changed to something like this (src_key_padding_mask needs to be True for the positions that should be masked out):
```python
def forward(self, input):
    if self.padding_idx is not None:
        mask = input != self.padding_idx
        # True for the padding positions that the attention should ignore
        src_key_padding_mask = torch.logical_not(mask)
    else:
        mask = input == input
        src_key_padding_mask = None
    x = self.embd(input)     # (B, T, D)
    x = self.position(x)     # (B, T, D)
    # Because the result of our code is (B, T, D), but transformers
    # take input as (T, B, D), we have to permute the order
    # of the dimensions before and after
    x = self.transformer(x.permute(1, 0, 2), src_key_padding_mask=src_key_padding_mask)  # (T, B, D)
    x = x.permute(1, 0, 2)   # (B, T, D)
    # average over time
    context = x.sum(dim=1) / mask.sum(dim=1).unsqueeze(1)
    return self.pred(self.attn(x, context, mask=mask))
```
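
For what it's worth, here is a minimal, self-contained sketch of the src_key_padding_mask semantics (the sizes, lengths, and variable names below are made up for illustration, not taken from the book's code): with the mask passed in, editing the padded positions no longer changes the encoder output at the real tokens, which is the behavior the forward pass above should get.

```python
import torch
from torch import nn

torch.manual_seed(0)

# A tiny encoder just for checking the mask semantics; sizes are arbitrary.
layer = nn.TransformerEncoderLayer(d_model=8, nhead=2)
encoder = nn.TransformerEncoder(layer, num_layers=1)
encoder.eval()  # disable dropout so the comparisons below are deterministic

B, T, D = 2, 5, 8
x = torch.randn(T, B, D)        # default layout is (T, B, D)
lengths = torch.tensor([5, 3])  # second sequence has 2 padding positions at the end

# True marks the positions the attention should ignore, shape (B, T)
src_key_padding_mask = torch.arange(T).unsqueeze(0) >= lengths.unsqueeze(1)

# Changing the content of the padded positions must not affect the real tokens
x_altered = x.clone()
x_altered[3:, 1, :] = 99.0

with torch.no_grad():
    out = encoder(x, src_key_padding_mask=src_key_padding_mask)
    out_altered = encoder(x_altered, src_key_padding_mask=src_key_padding_mask)
    print(torch.allclose(out[:3, 1], out_altered[:3, 1]))                    # True: padding is ignored
    print(torch.allclose(encoder(x)[:3, 1], encoder(x_altered)[:3, 1]))      # False: without the mask, padding leaks in
```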