bhavnicksm/vanilla-transformer-jax

Vanilla Transformer

JAX/Flax implementation of 'Attention Is All You Need' by Vaswani et al. (https://arxiv.org/abs/1706.03762)

Installation

Install the package with pip:

pip install vanilla-transformer-jax

Usage

To use the entire Transformer model (encoder and decoder), do the following:

import jax
import jax.numpy as jnp
from jax import random
from vtransformer import Transformer  # imports the Transformer class

model = Transformer()  # hyperparameters can be tuned; otherwise the defaults from the paper are used

prng = random.PRNGKey(42)

# batches of random token ids: 3 source sequences of length 4, 3 target sequences of length 5
example_input_src = jax.random.randint(prng, (3, 4), minval=0, maxval=10000)
example_input_trg = jax.random.randint(prng, (3, 5), minval=0, maxval=10000)
mask = jnp.array([1, 1, 1, 0, 0])  # one entry per target position

init = model.init(prng, example_input_src, example_input_trg, mask)  # initializing the params of the model

output = model.apply(init, example_input_src, example_input_trg, mask)  # getting the output
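
For readers unfamiliar with how a mask such as the one above typically enters the computation, here is a minimal sketch of masked scaled dot-product attention from the paper, written in plain JAX. It is not this package's code: the function name `scaled_dot_product_attention`, the argument `key_mask`, and the convention that 1 means "attend" and 0 means "ignore" are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def scaled_dot_product_attention(q, k, v, key_mask=None):
    """Illustrative sketch only (hypothetical helper, not part of vtransformer).

    q: (..., Lq, d), k: (..., Lk, d), v: (..., Lk, dv)
    key_mask: (Lk,) with 1 = attend, 0 = ignore (assumed convention).
    """
    d_k = q.shape[-1]
    # similarity of every query with every key, scaled by sqrt(d_k)
    scores = jnp.einsum("...qd,...kd->...qk", q, k) / jnp.sqrt(d_k)
    if key_mask is not None:
        # masked keys get a large negative score so softmax sends them to ~0
        scores = jnp.where(key_mask.astype(bool), scores, -1e9)
    weights = jax.nn.softmax(scores, axis=-1)
    # weighted sum of the values
    return jnp.einsum("...qk,...kd->...qd", weights, v), weights


key = jax.random.PRNGKey(0)
kq, kk, kv = jax.random.split(key, 3)
q = jax.random.normal(kq, (3, 5, 8))
k = jax.random.normal(kk, (3, 5, 8))
v = jax.random.normal(kv, (3, 5, 8))
mask = jnp.array([1, 1, 1, 0, 0])

out, weights = scaled_dot_product_attention(q, k, v, mask)
```

With this convention, the last two key positions receive essentially zero attention weight, and each row of `weights` still sums to one.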

To use the Encoder and Decoder separately, do the following:

encoding = model.encoder(init, example_input_src)  #using only the encoder
decoding = model.decoder(init, example_input_trg, encoding, mask) #using only the decoder

Contributing

This library is not perfect and can be improved in several respects.

Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.

Please make sure to update tests as appropriate.

License

MIT