12dash/TransformerAttention

Transformers & Attention

This repository is meant to help understand how attention models work. I plan to implement the basic structure of the attention and transformer models from the paper Attention Is All You Need.

There is an easier way to do this

PyTorch has an implementation of multi-head attention:

torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)

This takes care of most of what I have tried to do here. The goal was to understand the implementation rather than to use it for a particular problem.
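For reference, here is a minimal sketch of calling the built-in module for self-attention with a padding mask. The sizes and the mask are toy values chosen for illustration and are not taken from this repository; `batch_first=True` is also just a choice made here.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy sizes (assumption): embed_dim must be divisible by num_heads.
embed_dim, num_heads = 16, 4
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

# Batch of 2 sequences, 5 tokens each, already embedded.
x = torch.randn(2, 5, embed_dim)

# True marks padding positions that should not be attended to.
key_padding_mask = torch.tensor([
    [False, False, False, True, True],
    [False, False, False, False, False],
])

# Self-attention: query, key and value are the same tensor.
out, weights = mha(x, x, x, key_padding_mask=key_padding_mask)

print(out.shape)      # (batch, target length, embed_dim)
print(weights.shape)  # (batch, target length, source length), averaged over heads
```

The returned weights are averaged over the heads by default; passing `average_attn_weights=False` to the forward call returns one attention map per head instead.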

Implementation

I did the implementation step by step in Jupyter notebooks. You can go to the subfolder 'notebooks/' to check out the readme for each individual component, though I think it will be easier to follow the code in the model/ folder.

To-Do

  • Implementation of the transformer from the paper.
    • Positional Encoding
    • Multi-head Attention
    • Encoder Decoder
    • Transformer architecture
  • Basic Examples to train :
    • NLP tasks
    • Time series model to train as auto-regressive

Example : Neural Machine Translation (English -> Hindi)

I wanted to see whether the transformer model I created works on a simple task such as translation.
Link to code : Translation

Dataset

I used a dataset of English-Hindi paired sentences, which can be found at the link : Dataset. I used Hindi since I understand the language, which made it easier to experiment with : )
The dataset looks something like this :

It consists of around 2979 paired sentences, i.e. an English sentence and its corresponding sentence in Hindi. It needed some preprocessing (though I don't think this did anything to the Hindi words and punctuation), such as :

  • Removing punctuation
  • Lower-casing
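The two steps above can be sketched as follows. This is an assumption about how such cleaning might look, not the repository's actual code; `preprocess` is a hypothetical helper. Because `string.punctuation` only covers ASCII symbols, Hindi punctuation such as the danda (।) passes through unchanged, which would be consistent with the observation that the Hindi side is largely unaffected.

```python
import string

# Translation table that deletes ASCII punctuation only.
_PUNCT_TABLE = str.maketrans("", "", string.punctuation)


def preprocess(sentence: str) -> str:
    """Lower-case a sentence and strip ASCII punctuation (hypothetical helper)."""
    cleaned = sentence.lower().translate(_PUNCT_TABLE)
    # Collapse any runs of whitespace left behind by removed symbols.
    return " ".join(cleaned.split())


print(preprocess("Hello, World!"))  # hello world
print(preprocess("नमस्ते।"))          # unchanged: no ASCII punctuation, no letter case
```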

Model

I used the transformer architecture built in this repository. I also added positional encoding and some masking that is specific to this problem (we don't want attention to attend to padding, or the loss to include the padding).
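To make the masking idea concrete, here is a small pure-Python sketch. The padding id `PAD = 0` and all function names are assumptions made for illustration; the actual model presumably does the same thing on tensors.

```python
import math

PAD = 0  # hypothetical id of the padding token


def padding_mask(token_ids):
    """True where a position holds a real token, False where it is padding."""
    return [tok != PAD for tok in token_ids]


def masked_softmax(scores, mask):
    """Softmax over scores that gives zero weight to masked-out positions."""
    kept = [s for s, keep in zip(scores, mask) if keep]
    top = max(kept)  # subtract the max for numerical stability
    exps = [math.exp(s - top) if keep else 0.0 for s, keep in zip(scores, mask)]
    total = sum(exps)
    return [e / total for e in exps]


def masked_mean_loss(per_token_loss, mask):
    """Average the per-token loss over real tokens only."""
    kept = [loss for loss, keep in zip(per_token_loss, mask) if keep]
    return sum(kept) / len(kept)


tokens = [5, 9, 2, PAD, PAD]
mask = padding_mask(tokens)

# Padding positions receive zero attention weight.
weights = masked_softmax([1.0, 1.0, 1.0, 1.0, 1.0], mask)

# The large loss values at padding positions do not affect the average.
loss = masked_mean_loss([0.3, 0.6, 0.9, 7.0, 7.0], mask)
print(weights, loss)
```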

Some parameters of the model:

  • Number of stacks in encoder-decoder : 6
  • Heads in multi-head attention : 4
  • Dimension : 512
  • Query-Key dimension : 256
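With a model dimension of 512, the positional encoding added to the token embeddings can be sketched as below. This assumes the sinusoidal variant from the paper; the README does not say which variant is used, and `positional_encoding` is a hypothetical name.

```python
import math


def positional_encoding(max_len, d_model):
    """Sinusoidal positional encoding as a max_len x d_model list of lists.

    Even columns hold sin(pos / 10000^(i / d_model)) and the following odd
    column holds the matching cosine, where i is the even column index.
    """
    pe = [[0.0] * d_model for _ in range(max_len)]
    for pos in range(max_len):
        for i in range(0, d_model, 2):
            angle = pos / (10000 ** (i / d_model))
            pe[pos][i] = math.sin(angle)
            if i + 1 < d_model:
                pe[pos][i + 1] = math.cos(angle)
    return pe


pe = positional_encoding(max_len=10, d_model=512)
print(len(pe), len(pe[0]))  # 10 512
```

Each row is added element-wise to the embedding of the token at that position, so the model can tell positions apart even though attention itself is order-agnostic.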

Currently, running this on MPS is really slow, which I think is caused by the matrix multiplications, since I am able to run it on CUDA. It is still faster than on the CPU.
Moreover, only the training part of the model is present for now, and it uses teacher forcing. In teacher forcing, the decoder's input for predicting the next Hindi word is the actual previous word from the target sentence rather than the word the decoder itself predicted. During prediction, however, this part becomes autoregressive: the decoder's output is fed back in to produce the next translated word.
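The difference between the two modes can be sketched in plain Python. The `BOS`/`EOS` ids, the function names, and the toy decoder are all assumptions made for illustration; greedy decoding is just one simple way to run the autoregressive loop.

```python
BOS, EOS = 1, 2  # hypothetical ids for start- and end-of-sentence tokens


def teacher_forcing_pair(target_ids):
    """Training: the decoder sees the true previous tokens, shifted right."""
    decoder_input = [BOS] + target_ids  # what the decoder is given
    labels = target_ids + [EOS]         # what it must predict at each step
    return decoder_input, labels


def greedy_decode(next_token, max_len=20):
    """Prediction: feed each predicted token back in as the next input."""
    generated = [BOS]
    for _ in range(max_len):
        tok = next_token(generated)  # model call on everything generated so far
        if tok == EOS:
            break
        generated.append(tok)
    return generated[1:]  # drop the BOS marker


def toy_next_token(prefix):
    """Stand-in for a trained decoder: emits 7, 8, 9 and then stops."""
    script = [7, 8, 9, EOS]
    return script[len(prefix) - 1]


decoder_input, labels = teacher_forcing_pair([7, 8, 9])
print(decoder_input, labels)          # [1, 7, 8, 9] [7, 8, 9, 2]
print(greedy_decode(toy_next_token))  # [7, 8, 9]
```

With teacher forcing all target positions can be trained in one pass, whereas the prediction loop needs one decoder call per generated token.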

Results

The result for one of the cases (still taken from the training set) is :

The attention across the 4 heads of the encoder-decoder multi-head attention is shown below :

About

An implementation of the transformer model and experimentation on some tasks
