JAX/Flax implementation of "Attention Is All You Need" by Vaswani et al. (https://arxiv.org/abs/1706.03762)
Use the pip package manager to install the package:
pip install vanilla-transformer-jax
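To check that the install worked, the package should import cleanly (the top-level module is named vtransformer, as used in the examples below):

python -c "import vtransformer"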
To use the entire Transformer model (encoder and decoder), you can do so as follows:
import jax
import jax.numpy as jnp
from vtransformer import Transformer # imports the Transformer class
model = Transformer() # model hyperparameters can be tuned; otherwise the defaults from the paper are used
prng = jax.random.PRNGKey(42)
example_input_src = jax.random.randint(prng, (3, 4), minval=0, maxval=10000) # (batch, src_len) token ids
example_input_trg = jax.random.randint(prng, (3, 5), minval=0, maxval=10000) # (batch, trg_len) token ids
mask = jnp.array([1, 1, 1, 0, 0]) # 1 = real token, 0 = padding
init = model.init(prng, example_input_src, example_input_trg, mask) # initializing the params of the model
output = model.apply(init, example_input_src, example_input_trg, mask) # getting the output
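As a rough illustration of training, here is a minimal loss-computation sketch. It assumes the model returns logits of shape (batch, trg_len, vocab_size) and that optax is installed; neither is guaranteed by this library, so check its source before relying on this:

import jax
import optax # assumed to be installed separately

def loss_fn(params, src, trg, mask):
    # Assumption: model.apply returns per-position logits over the target
    # vocabulary, shaped (batch, trg_len, vocab_size).
    logits = model.apply(params, src, trg, mask)
    # Hypothetical choice: score positions against the target ids themselves;
    # a real setup would shift the targets for next-token prediction.
    return optax.softmax_cross_entropy_with_integer_labels(logits, trg).mean()

grads = jax.grad(loss_fn)(init, example_input_src, example_input_trg, mask)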
To use the Encoder and Decoder separately, you can do so as follows:
encoding = model.encoder(init, example_input_src) # using only the encoder
decoding = model.decoder(init, example_input_trg, encoding, mask) # using only the decoder
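The split API lends itself to autoregressive generation. Below is a minimal greedy-decoding sketch; BOS_ID, MAX_LEN, and the assumption that the decoder returns per-position vocabulary logits are all hypothetical, not part of this library's documented API:

import jax.numpy as jnp

BOS_ID, MAX_LEN = 1, 32 # hypothetical start-token id and length cap
encoding = model.encoder(init, example_input_src) # encode the source once
trg = jnp.full((example_input_src.shape[0], 1), BOS_ID) # start every sequence with BOS
for _ in range(MAX_LEN):
    # Reuses the earlier mask for illustration; a real setup would rebuild it each step.
    logits = model.decoder(init, trg, encoding, mask) # assumed shape: (batch, len, vocab)
    next_tok = jnp.argmax(logits[:, -1, :], axis=-1)[:, None] # most likely next token
    trg = jnp.concatenate([trg, next_tok], axis=-1) # append and continue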
This library is not perfect and can be improved in quite a few respects.
Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.
Please make sure to update tests as appropriate.