MiniGPT

A decoder-only language model implemented in JAX.

Quick tour

If you want to see some code, have a look at the source files in the repository.

Details and Features

MiniGPT combines findings and ideas from the Cramming paper by Geiping et al. and from LLaMA by Touvron et al.

MiniGPT has...

  • Pre-norm transformer blocks (LayerNorm could be changed to RMSNorm, but that would break my already trained models - although I could add a flag in the model's config)
  • Rotary positional embeddings (see the sketch after this list)
  • SwiGLU activations
  • Shared weights between embeddings and outputs
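
To make the rotary-embedding item concrete, below is a minimal sketch of rotary position embeddings (RoPE) in plain JAX. The function name, argument shapes, and the interleaved channel-pair convention are assumptions made for illustration; MiniGPT's actual implementation may differ.

import jax.numpy as jnp

def apply_rotary_embedding(x: jnp.ndarray, base: float = 10_000.0) -> jnp.ndarray:
    """Apply RoPE to queries or keys of shape [seq_len, num_heads, head_dim]."""
    seq_len, _, head_dim = x.shape
    # One rotation frequency per pair of channels.
    inv_freq = 1.0 / (base ** (jnp.arange(0, head_dim, 2) / head_dim))
    angles = jnp.arange(seq_len)[:, None] * inv_freq[None, :]  # [seq_len, head_dim // 2]
    sin = jnp.sin(angles)[:, None, :]  # broadcast over heads
    cos = jnp.cos(angles)[:, None, :]
    x1, x2 = x[..., 0::2], x[..., 1::2]
    # Rotate each (x1, x2) channel pair by its position-dependent angle.
    rotated = jnp.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
    return rotated.reshape(x.shape)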

Notably, I found training to be much more stable with GELU activations instead of SwiGLU. The latter does, however, lead to a lower loss.
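
For reference, here is a minimal sketch of the two feed-forward variants in plain JAX; the parameter names and shapes are illustrative (biases omitted), not MiniGPT's actual module API.

import jax

def gelu_mlp(x, w_in, w_out):
    # Classic transformer feed-forward: project up, apply GELU, project back down.
    return jax.nn.gelu(x @ w_in) @ w_out

def swiglu_mlp(x, w_gate, w_in, w_out):
    # SwiGLU: the up-projection is gated element-wise by a SiLU ("swish") branch.
    return (jax.nn.silu(x @ w_gate) * (x @ w_in)) @ w_out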

When training, MiniGPT spins up three threads: one for data loading and pre-processing, one for the actual training, and one for telemetry such as logging and saving. Data is passed between the threads via queues so that each thread can work asynchronously, which mitigates drops in GPU utilisation, e.g. while parameters are being saved to disk.
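
A rough sketch of this layout using only the Python standard library is shown below; the thread bodies and helper functions (load_and_preprocess_next_batch, train_step, log_and_maybe_save) are hypothetical placeholders, not MiniGPT's actual code.

import queue
import threading

batch_queue = queue.Queue(maxsize=8)      # data loader -> trainer
telemetry_queue = queue.Queue(maxsize=8)  # trainer -> telemetry (logging, saving)

def data_worker():
    while True:
        batch = load_and_preprocess_next_batch()  # hypothetical helper
        batch_queue.put(batch)  # blocks if the trainer falls behind

def train_worker():
    while True:
        batch = batch_queue.get()
        metrics, params = train_step(batch)  # hypothetical helper
        telemetry_queue.put((metrics, params))  # hand off without stalling the GPU

def telemetry_worker():
    while True:
        metrics, params = telemetry_queue.get()
        log_and_maybe_save(metrics, params)  # hypothetical helper: logging, checkpointing

for worker in (data_worker, train_worker, telemetry_worker):
    threading.Thread(target=worker, daemon=True).start()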

MiniGPT supports...

  • Automatic mixed precision training
  • Gradient checkpointing
  • Multi-GPU support via data parallelism
  • Gradient accumulation for arbitrarily large batch sizes (see the sketch after this list)
  • Batch-size schedules
  • An AdamW optimiser
  • Complete reproducibility (models trained with the same config and seed will be exactly equal)
  • Remote telemetry with W&B
  • Streaming arbitrary datasets from the Hugging Face hub and combining them for training
  • Loading arbitrary tokenizers from disk and the Hugging Face tokenizers library
  • Full dockerisation for an easy setup on remote machines
  • Finding the maximum batch-size for a given model configuration via binary search
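
As an illustration of the gradient-accumulation item above, here is a minimal sketch using Optax's MultiSteps wrapper. This is a generic example under assumed names: loss_fn, the learning rate, and the accumulation count are placeholders, not MiniGPT's actual code.

import jax
import optax

ACCUM_STEPS = 8  # effective batch size = device batch size * ACCUM_STEPS

# MultiSteps accumulates gradients and only applies a real update every
# ACCUM_STEPS calls; in between, the returned updates are zero.
optimiser = optax.MultiSteps(optax.adamw(learning_rate=3e-4), every_k_schedule=ACCUM_STEPS)
# opt_state = optimiser.init(params)

@jax.jit
def train_step(params, opt_state, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)  # loss_fn is a placeholder
    updates, opt_state = optimiser.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss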

Using MiniGPT

Either run pip install -r requirements.txt (a prior installation of JAX is assumed) or use Docker:

docker build -t minigpt .
docker run --rm -it --gpus all -v $PWD:/workdir/ minigpt bash

The Docker route requires the Nvidia container runtime.

The scripts inside ./scripts/ should be all you need to train and run a model. To do so, follow the steps below.

  1. Create a configuration file. I recommend adapting a copy of ./configs/dev.yaml to your needs.
  2. To train a model, run (we use ./configs/dev.yaml as an example):
./scripts/train.py train \
    --config-path ./configs/dev.yaml \
    --seed 42 \
    --save-directory ./zoo/example-run/

The options should be self-explanatory. Run ./scripts/train.py train --help for an exhaustive list with help texts.

  3. To generate text with a trained model, run:
./scripts/generate.py generate \
    --load ./zoo/example-run/ \
    --seed 42 \
    --temperature 0.8 \
    --top-p 0.95 \
    --prompt "A long time ago"

Results

I trained a model with 100M parameters for 26 hours on an A100, i.e. for about 30 USD. Its config can be found in ./configs/100M.yaml. The final loss was 2.996. The training data was a combination of C4, Bookcorpusopen, Wikitext, and Multi-news, and the dataset did not repeat during training. The throughput was approximately 114,000 tokens/second (512 sequence length * 64 device batch size / 0.288 seconds per device batch). The texts below were generated by the model with the prompt in italics:

raccoons have been around for thousands of years. in fact, it is believed that the same ancient egyptian culture of raccoons used to live in the middle east and africa. this ancient culture was born out of the invention of the first cell phone. the first cell phone was invented in the late 19th century. today, many people use cell phones to communicate and do research. they also have the ability to communicate with their loved ones. in the early 20th century, people used the internet to communicate with their loved ones.

Hm, interesting. It has clearly figured out grammar and can make some semantic associations, although these are highly inconsistent; for example, the latter part of the text is broadly about communication technologies.
