
Use padding mask for attention in SimpleTransformerClassifier #13

Open
@felix-ha

Description


I think a padding mask for the attention should be used in the forward pass of the TransformerEncoder.
The padding tokens need to be excluded when computing the attention weights. This is related to Chapter 12.2.1.

See cell 33 here. See also the PyTorch docs for reference.
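As a quick illustration of the convention (a minimal standalone sketch; the token ids and `padding_idx = 0` are made up and not taken from the book's code), `True` marks a position the encoder should ignore:

```python
import torch

# Hypothetical padded batch: two sequences, the second padded with 0-tokens
# (padding_idx = 0 is an assumption for this sketch).
tokens = torch.tensor([[4, 7, 2, 9, 1],
                       [4, 7, 2, 0, 0]])

mask = tokens != 0                               # (B, T) True for real tokens
src_key_padding_mask = torch.logical_not(mask)   # (B, T) True for padding positions
print(src_key_padding_mask)
# tensor([[False, False, False, False, False],
#         [False, False, False,  True,  True]])
```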

It should be changed into something like this (`src_key_padding_mask` needs to be `True` for the positions that should be masked out):

```python
def forward(self, input):
    if self.padding_idx is not None:
        mask = input != self.padding_idx                # (B, T) True for real tokens
        src_key_padding_mask = torch.logical_not(mask)  # True for positions to mask out
    else:
        mask = input == input                           # all-True mask: nothing is padding
        src_key_padding_mask = None
    x = self.embd(input)     # (B, T, D)
    x = self.position(x)     # (B, T, D)
    # Because the result of our code is (B, T, D), but transformers
    # take input as (T, B, D), we have to permute the order
    # of the dimensions before and after
    x = self.transformer(x.permute(1, 0, 2), src_key_padding_mask=src_key_padding_mask)  # (T, B, D)
    x = x.permute(1, 0, 2)   # (B, T, D)
    # average over time
    context = x.sum(dim=1) / mask.sum(dim=1).unsqueeze(1)
    return self.pred(self.attn(x, context, mask=mask))
```
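
For reference, here is a standalone sketch (the dimensions and the single encoder layer are made up for illustration, not the book's model) showing that `nn.TransformerEncoder` accepts a `(B, T)` boolean `src_key_padding_mask` alongside `(T, B, D)` input:

```python
import torch
import torch.nn as nn

# Toy dimensions for illustration only.
B, T, D = 2, 5, 16
layer = nn.TransformerEncoderLayer(d_model=D, nhead=4)
encoder = nn.TransformerEncoder(layer, num_layers=1)

x = torch.randn(B, T, D)   # stands in for the embedded + position-encoded batch
src_key_padding_mask = torch.tensor([[False, False, False, False, False],
                                     [False, False, False, True,  True]])  # (B, T)

out = encoder(x.permute(1, 0, 2),                        # (T, B, D), the default layout
              src_key_padding_mask=src_key_padding_mask)
print(out.shape)  # torch.Size([5, 2, 16]) -> (T, B, D)
```

With the mask passed in, the padded positions no longer receive attention weight when the representations of the real tokens are computed.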
