Train a GPT-like model to generate French rap lyrics.
Essentially a fun and educational personal project: learning how to design and train a GPT-like architecture from scratch.
This project uses Rye. Make sure it's installed on your system.
To install all dependencies (data download, training, etc.), run `rye sync --all-features` in the project directory.
This project uses the French Rap Lyrics Kaggle dataset.
To download it, register your Kaggle API token (see the Kaggle API instructions): simply download your `kaggle.json` token and move it to `~/.kaggle/kaggle.json`.
Then run `python scripts/download_data.py`.
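As a sketch, the token-installation step above can also be done programmatically. The helper below is illustrative (it is not part of the repo); the `~/.kaggle/kaggle.json` location and the owner-only 600 permissions are what the Kaggle API client expects.

```python
from pathlib import Path
import shutil
import stat

def install_kaggle_token(src: Path, kaggle_dir: Path) -> Path:
    """Copy a downloaded kaggle.json into the directory the Kaggle API
    client reads (~/.kaggle by default) and restrict it to mode 600,
    since the client warns when the token is readable by other users."""
    kaggle_dir.mkdir(parents=True, exist_ok=True)
    dest = kaggle_dir / "kaggle.json"
    shutil.copy(src, dest)
    dest.chmod(stat.S_IRUSR | stat.S_IWUSR)  # 600
    return dest

# Typical usage (paths are examples):
# install_kaggle_token(Path("~/Downloads/kaggle.json").expanduser(),
#                      Path.home() / ".kaggle")
```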
Make sure you have access to a decent GPU, as the default model config is fairly VRAM-heavy.
From the repo root, run `python scripts/train.py` with an optional config arg (pointing to `configs/config.toml` by default).
The best model is tracked by `torcheval.metrics.Perplexity` and saved to disk. By default, checkpoints are saved in `checkpoints/<run_name>`.
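Perplexity is the exponential of the mean negative log-likelihood the model assigns to the target tokens (lower is better). A minimal pure-Python illustration of the metric, not the project's `torcheval`-based implementation:

```python
import math

def perplexity(target_probs):
    """Perplexity = exp(mean negative log-likelihood), where
    `target_probs` are the probabilities the model assigned to
    each correct next token. A perfect model scores 1.0."""
    nll = -sum(math.log(p) for p in target_probs) / len(target_probs)
    return math.exp(nll)

print(perplexity([0.5, 0.25, 0.125]))  # ≈ 4.0, the geometric mean of 2, 4, 8
```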
NOTE: This project uses WandB to log and record experiments. If your training config specifies `wandb.mode = online`, make sure you've registered your account with your API key.
Once your model is trained, you can push the checkpoint to the HF Hub using `scripts/push_to_hf_hub.py` with the correct arguments. It will push the following three components:
- `model.pt` (specified argument), converted to `model.safetensors` (using the `rapgpt.model.HFHubTransformerModel` mixin) for ease of inference on an HF Space;
- `config.toml` (specified argument);
- `artists_tokens.txt` (specified argument).
These three components are required for inference (see next section).
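For instance, a small helper (hypothetical, not part of the repo) could verify that a downloaded checkpoint directory contains all three components before launching inference:

```python
from pathlib import Path

# The three components listed above; the helper itself is illustrative.
REQUIRED = ("model.safetensors", "config.toml", "artists_tokens.txt")

def missing_components(checkpoint_dir: Path) -> list[str]:
    """Return the names of required inference files absent from the directory."""
    return [name for name in REQUIRED if not (checkpoint_dir / name).is_file()]
```

An empty return value means the checkpoint directory is complete.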
This project uses Gradio for local and online inference.
Local inference is done with the `python app/app.py` script. Some additional arguments can be passed, essentially indicating whether to use the default checkpoint on the HF Hub or a local checkpoint.
Online inference is served on HF through the (more or less) same Gradio app. It automatically calls the default checkpoint on the HF Hub for inference.
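The hub-vs-local switch could look like the following `argparse` sketch. The flag names are assumptions for illustration; the actual arguments of `app/app.py` may differ:

```python
import argparse

# Hypothetical CLI; the real argument names in app/app.py may differ.
parser = argparse.ArgumentParser(description="rapgpt Gradio inference")
parser.add_argument(
    "--local-checkpoint",
    default=None,
    help="path to a local checkpoint dir; omit to pull the default HF Hub checkpoint",
)

args = parser.parse_args([])  # empty argv → fall back to the Hub checkpoint
checkpoint = args.local_checkpoint or "hf-hub default (illustrative placeholder)"
print(checkpoint)
```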
Since this project is mostly personal / educational (and since I'm GPU poor), it is probably not production-ready in its current state (and I don't intend it to be in the foreseeable future). However, here are some interesting leads I plan on exploring:
- I noticed the style of each rapper isn't sufficiently marked. To enforce this more during training, I want to try adding a classification head and backpropagating a combined language-modeling + classification loss;
- Clean up the code / use more production-ready modules (e.g. FlashAttention);
- Train in fp16;
- Find a way to select multiple artist tokens (mixing styles could be fun).
Inspired by the great nanoGPT.