SLM (Small Language Model) trained on French Rap Lyrics

hugojarkoff/rapGPT

rapGPT

Train a GPT-like model to generate French rap lyrics.

Essentially a fun and educational personal project, learning how to design and train a GPT-like architecture from scratch.

0. Dependencies management

This project uses Rye. Make sure it's installed on your system.

To install all dependencies (for downloading data, training, etc.), run rye sync --all-features from the project directory.

1. Data

This project uses French Rap Lyrics Kaggle dataset.

To download it, register your Kaggle API token. See instructions here. In short, download your kaggle.json token and move it to ~/.kaggle/kaggle.json.
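As a quick sanity check before running the download script, the token location described above can be verified with a few lines of Python. This is a minimal sketch, not part of the repository: the helper name check_kaggle_token is hypothetical, and the permission check assumes a POSIX system.

```python
from __future__ import annotations

import stat
from pathlib import Path


def check_kaggle_token(home: Path | None = None) -> list[str]:
    """Return a list of problems found with the Kaggle API token (empty if none).

    Hypothetical helper: it only inspects ~/.kaggle/kaggle.json and never
    contacts Kaggle.
    """
    home = home or Path.home()
    token = home / ".kaggle" / "kaggle.json"
    problems = []
    if not token.is_file():
        problems.append(f"missing token: {token}")
        return problems
    # The Kaggle client warns when the token is readable by other users.
    mode = stat.S_IMODE(token.stat().st_mode)
    if mode & (stat.S_IRWXG | stat.S_IRWXO):
        problems.append(f"token is accessible to other users; run: chmod 600 {token}")
    return problems


if __name__ == "__main__":
    for problem in check_kaggle_token():
        print(problem)
```

An empty result means the token is where the Kaggle client expects it and is private to your user.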

Then run python scripts/download_data.py.

2. Train

Make sure you have access to a decent GPU, as the default model config is pretty VRAM-heavy.

From the repo root, run python scripts/train.py with an optional config argument (defaults to configs/config.toml).

The best model, as measured by torcheval.metrics.Perplexity, is tracked and saved to disk. By default, checkpoints are saved in checkpoints/<run_name>.
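To make the "best model by perplexity" idea concrete, here is a dependency-free sketch. It is not the repository's training loop: the project relies on torcheval.metrics.Perplexity, whereas this example computes perplexity by hand as the exponential of the mean per-token negative log-likelihood. The class name BestCheckpointTracker and the per-epoch evaluation are assumptions made for illustration.

```python
import math


def perplexity(token_nlls):
    """Perplexity = exp(mean negative log-likelihood per token)."""
    return math.exp(sum(token_nlls) / len(token_nlls))


class BestCheckpointTracker:
    """Remember the evaluation with the lowest perplexity (lower is better)."""

    def __init__(self):
        self.best_ppl = math.inf
        self.best_epoch = None

    def update(self, epoch, ppl):
        """Return True when `ppl` improves on the best value seen so far."""
        if ppl < self.best_ppl:
            self.best_ppl, self.best_epoch = ppl, epoch
            return True  # the caller would save the checkpoint at this point
        return False


if __name__ == "__main__":
    tracker = BestCheckpointTracker()
    # Made-up per-token losses for three evaluations, purely illustrative.
    for epoch, nlls in enumerate([[2.0, 2.2], [1.5, 1.7], [1.6, 1.8]]):
        improved = tracker.update(epoch, perplexity(nlls))
        print(epoch, "saved" if improved else "skipped")
```

Only evaluations that lower the perplexity trigger a save, so the file on disk always corresponds to the best evaluation so far.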

NOTE: This project uses WandB to log and record experiments. If your training config specifies wandb.mode = online, make sure you've registered your account with your API key.

3. Pushing to HF Hub

Once your model is trained, you can push the checkpoint to HF Hub using scripts/push_to_hf_hub.py with the appropriate arguments. It pushes the following three components:

  • model.pt (specified argument), converted to model.safetensors (using the rapgpt.model.HFHubTransformerModel mixin) for ease of inference on HF Space;
  • config.toml (specified argument);
  • artists_tokens.txt (specified argument).

These three components are required for inference (see next section).

4. Local Inference

This project uses Gradio for local and online inference.

To run local inference, use python app/app.py. Additional arguments control whether to use the default checkpoint on HF Hub or a local checkpoint.

5. Online Inference

Online inference is served on HF through (more or less) the same Gradio app. It automatically loads the default checkpoint from HF Hub.

Future Work / Ideas

Since this project is mostly personal / educational (and since I'm GPU poor), it is probably not production-ready in its current state, and I don't intend to make it so in the foreseeable future. However, here are some interesting leads I plan on exploring:

  • I noticed the style of each rapper isn't sufficiently marked. To enforce it more strongly during training, I want to try adding a classification head and backpropagating the sum of the language-modelling (logits) loss and the classification loss;
  • Clean up the code / use more production-ready modules (e.g. FlashAttention)
  • Train in fp16
  • Find a way to select multiple artists tokens (for mixing styles, could be fun)
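The classification-head idea from the first bullet can be sketched as a weighted sum of two cross-entropy terms. This is a framework-free illustration under assumptions, not the planned implementation: the function names and the cls_weight parameter are hypothetical, and the source only says the two losses would be combined. In PyTorch, each term would typically come from torch.nn.functional.cross_entropy.

```python
import math


def cross_entropy(logits, target):
    """Negative log-softmax probability of the `target` index."""
    peak = max(logits)  # subtract the max for numerical stability
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return log_norm - logits[target]


def combined_loss(token_logits, next_token, artist_logits, artist, cls_weight=0.5):
    """Language-modelling loss plus a weighted artist-classification loss.

    `cls_weight` is an assumed hyperparameter balancing the two terms.
    """
    lm_loss = cross_entropy(token_logits, next_token)
    cls_loss = cross_entropy(artist_logits, artist)
    return lm_loss + cls_weight * cls_loss


if __name__ == "__main__":
    # Toy logits: a 3-token vocabulary and 2 artists, values purely illustrative.
    print(combined_loss([2.0, 0.5, -1.0], 0, [0.3, 1.2], 1))
```

A larger cls_weight pushes the model harder toward artist-discriminative features, at the possible cost of next-token quality, so it would need tuning.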

Credits

Inspired by the great nanoGPT.
