A decoder-only language model implemented using Jax.
If you want to see some code, have a look at:
- `./minigpt/nn.py` - The model and its components.
- `./minigpt/training.py` - The training loop and loss function.
- `./minigpt/inference.py` - Methods to use pretrained models for inference.
- `./minigpt/sidecar.py` - Training utilities such as W&B logging, auto-saving, etc.
- `./scripts/*.py` - Scripts exposing CLIs to interact with MiniGPT.
MiniGPT is the result of combining some findings and ideas from the Cramming paper by Geiping et al. and LLaMA by Touvron et al.
MiniGPT has...
- Pre-norm transformer blocks (LayerNorm could be swapped for RMSNorm, but that would break my already trained models; I could add a flag to the model's config for this)
- Rotary positional embeddings
- SwiGLU activations
- Shared weights between embeddings and outputs
Notably, I found training to be much more stable with GELU activations than with SwiGLU, although the latter does lead to a lower final loss.
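As a rough illustration of the SwiGLU block (not the actual code in `./minigpt/nn.py`; the parameter names and initialisation below are assumptions), a SwiGLU feed-forward layer can be sketched in plain Jax like this:

```python
import jax
import jax.numpy as jnp

def init_swiglu(key, d_model, d_ff):
    """Initialise weights for a SwiGLU feed-forward block (illustrative shapes)."""
    k_gate, k_up, k_down = jax.random.split(key, 3)
    return {
        "w_gate": jax.random.normal(k_gate, (d_model, d_ff)) * d_model ** -0.5,
        "w_up": jax.random.normal(k_up, (d_model, d_ff)) * d_model ** -0.5,
        "w_down": jax.random.normal(k_down, (d_ff, d_model)) * d_ff ** -0.5,
    }

def swiglu_ffn(params, x):
    """SwiGLU: (silu(x @ W_gate) * (x @ W_up)) @ W_down."""
    gate = jax.nn.silu(x @ params["w_gate"])  # gating path with Swish/SiLU
    up = x @ params["w_up"]                   # linear path
    return (gate * up) @ params["w_down"]     # project back to d_model
```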
When training, MiniGPT spins up three threads: one for data loading and pre-processing, one for the actual training, and one for telemetry such as logging and saving. Data is passed between the threads via queues so that each thread can work asynchronously, which mitigates drops in GPU utilisation, e.g. while parameters are being saved to disk.
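A minimal sketch of this producer/consumer layout (queue names, sizes, and worker signatures are made up for illustration and do not come from `./minigpt/sidecar.py`):

```python
import queue
import threading

# Hypothetical sketch of the three-thread layout; names do not match the repo.
batch_queue = queue.Queue(maxsize=8)      # data thread -> training thread
telemetry_queue = queue.Queue(maxsize=8)  # training thread -> telemetry thread

def data_worker(batches):
    for batch in batches:                 # load and pre-process on the CPU
        batch_queue.put(batch)            # blocks once the queue is full

def train_worker(train_step, state, num_steps):
    for _ in range(num_steps):
        batch = batch_queue.get()
        state, metrics = train_step(state, batch)  # keeps the GPU busy
        telemetry_queue.put(metrics)

def telemetry_worker(log_fn, num_steps):
    for _ in range(num_steps):
        log_fn(telemetry_queue.get())     # e.g. W&B logging, checkpoint saving

# Each worker runs in its own thread, e.g.:
# threading.Thread(target=data_worker, args=(batches,), daemon=True).start()
```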
MiniGPT supports...
- Automatic mixed precision training
- Gradient checkpointing
- Multi-GPU support via data parallelism
- Gradient accumulation for arbitrarily large batch sizes (see the sketch after this list)
- Batch-size schedules
- An AdamW optimiser
- Complete reproducibility (models trained with the same config and seed will be exactly equal)
- Remote telemetry with W&B
- Streaming arbitrary datasets from the Hugging Face hub and combining them for training
- Loading arbitrary tokenizers from disk and the Hugging Face tokenizers library
- Full dockerisation for an easy setup on remote machines
- Finding the maximum batch-size for a given model configuration via binary search
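To illustrate the gradient-accumulation item above, here is a hypothetical sketch in plain Jax; the real training loop in `./minigpt/training.py` will differ:

```python
import jax
import jax.numpy as jnp

def accumulated_grads(loss_fn, params, micro_batches):
    """Average gradients over several micro-batches so the effective batch
    size can exceed what fits in device memory (illustrative sketch only)."""
    grad_fn = jax.grad(loss_fn)
    grads = jax.tree_util.tree_map(jnp.zeros_like, params)
    for batch in micro_batches:
        grads = jax.tree_util.tree_map(jnp.add, grads, grad_fn(params, batch))
    return jax.tree_util.tree_map(lambda g: g / len(micro_batches), grads)
```

The averaged gradients are then passed to the optimiser exactly as if they came from one large batch.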
Either run `pip install -r requirements.txt` (note that a prior installation of Jax is assumed) or use Docker with `docker build -t minigpt .` and then `docker run --rm -it --gpus all -v $PWD:/workdir/ minigpt bash`. The latter requires the Nvidia container runtime.
The scripts inside `./scripts/` should be all you need to train and run a model. To do so, follow the steps below.
- Create a configuration file. I recommend adapting a copy of `./configs/dev.yaml` to your needs.
- To train a model, run (we use `./configs/dev.yaml` as an example):
./scripts/train.py train \
--config-path ./configs/dev.yaml \
--seed 42 \
--save-directory ./zoo/example-run/
The options should be self-explanatory. Run `./scripts/train.py train --help` for an exhaustive list with help texts.
- To generate text with a trained model, run:
./scripts/generate.py generate \
--load ./zoo/example-run/ \
--seed 42 \
--temperature 0.8 \
--top-p 0.95 \
--prompt "A long time ago"
I trained a model with 100M parameters for 26 hours on an A100, i.e. for about 30 USD. Its config can be found in `./configs/100M.yaml`. The final loss was 2.996; the training dataset was a combination of C4, Bookcorpusopen, Wikitext, and Multi-news and did not repeat during training. The throughput was approximately 114,000 tokens/second (512 sequence length * 64 device batch size / 0.288 seconds per device batch). The texts below were generated by the model with the prompt in italics:
raccoons have been around for thousands of years. in fact, it is believed that the same ancient egyptian culture of raccoons used to live in the middle east and africa. this ancient culture was born out of the invention of the first cell phone. the first cell phone was invented in the late 19th century. today, many people use cell phones to communicate and do research. they also have the ability to communicate with their loved ones. in the early 20th century, people used the internet to communicate with their loved ones.
Hm, interesting. It has clearly figured out grammar and can make some semantic associations, although these are highly inconsistent; for example, the latter part of the text is broadly about communication technologies.