[NeurIPS 2024] MATES: Model-Aware Data Selection for Efficient Pretraining with Data Influence Models
This is the official repository for MATES: Model-Aware Data Selection for Efficient Pretraining with Data Influence Models. The implementation is mainly based on LitGPT, which is easy to begin with, use, and modify.
Python version
The code is tested on Python 3.9.17.
Install basic dependencies
pip install -r requirements.txt
We use a tokenized version of the C4 dataset in our code. Please ensure your disk has at least 500 GB of storage for this dataset. To get the training data for the initial warmup 10k steps, please run:
python src/select_data/select_data.py
- The selected data will be saved in data/c4/pythia-410m/random/0.
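Judging from the example paths in this README (data/c4/pythia-410m/random/0 here and data/c4/pythia-410m/mates/40000 for MATES), selected data appears to be stored as data/c4/<model_name>/<method>/<ckpt>. The sketch below shows that pattern as a small helper for locating a stage's data. The helper name is hypothetical and the layout is inferred from the README, not read from the repository's code.

```python
from pathlib import Path


def selected_data_dir(model_name: str, method: str, ckpt: int, root: str = "data/c4") -> Path:
    """Return the directory where selected training data for a stage is expected.

    Hypothetical helper: the layout <root>/<model_name>/<method>/<ckpt> is
    inferred from the example paths in the README.
    """
    return Path(root) / model_name / method / str(ckpt)


# Data for the warmup stage (random selection, training from scratch).
print(selected_data_dir("pythia-410m", "random", 0).as_posix())  # data/c4/pythia-410m/random/0
```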
For preprocessing our reference task LAMBADA, please run:
python src/select_data/prepare_lambada.py
- The processed data will be saved in data/lambada_openai.
Our main experiments use 8 GPUs for parallelization.
Our pretraining is run stage by stage to facilitate the model-aware data selection. Each stage consists of 10k steps. For instance, in the initial warmup 10k steps, you can run:
model_name=pythia-410m \
method=random \
ckpt=0 \
decay=false \
bash scripts/pretrain.sh
ckpt=0 denotes that training starts from scratch.
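In the command above, the KEY=value prefix passes model_name, method, ckpt and decay to scripts/pretrain.sh as environment variables. If you prefer to drive stages from Python, a minimal sketch using the standard subprocess module might look like this. The helper names are hypothetical, and the script is assumed to read exactly these four variables, as the shell command suggests. By default the launcher only returns the command (dry run).

```python
import os
import subprocess


def pretrain_env(model_name: str, method: str, ckpt: int, decay: bool) -> dict:
    """Environment variables for scripts/pretrain.sh, mirroring the README command."""
    return {
        "model_name": model_name,
        "method": method,
        "ckpt": str(ckpt),
        "decay": "true" if decay else "false",
    }


def run_pretrain_stage(model_name: str, method: str, ckpt: int,
                       decay: bool = False, dry_run: bool = True) -> list:
    """Launch one pretraining stage; with dry_run=True, only return the command."""
    cmd = ["bash", "scripts/pretrain.sh"]
    if not dry_run:
        # Merge with the current environment so PATH, CUDA settings, etc. are kept.
        env = {**os.environ, **pretrain_env(model_name, method, ckpt, decay)}
        subprocess.run(cmd, env=env, check=True)
    return cmd


# Warmup stage from scratch, equivalent to the shell command above.
print(pretrain_env("pythia-410m", "random", 0, decay=False))
```

Passing the variables through env= rather than building a shell string avoids quoting issues and keeps the call equivalent to the VAR=value prefix form.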
To resume the pretraining from previous steps (e.g., 10k), you can run:
model_name=pythia-410m \
method=random \
ckpt=40000 \
decay=false \
bash scripts/pretrain.sh
ckpt=40000 corresponds to 10k training steps because our gradient accumulation step count is 4 (10k × 4 = 40,000 iterations). method=random denotes random data selection; you can replace it with mates to use MATES after the first 10k steps.
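The ckpt value counts iterations rather than optimizer steps. With a gradient accumulation step count of 4, 10k steps correspond to ckpt=40000. The sketch below shows this arithmetic together with a hypothetical stage plan for a MATES run: random selection for the warmup stage, MATES for every later stage. The constant and function names are assumptions for illustration.

```python
GRAD_ACCUM = 4         # gradient accumulation steps, as stated in this README
STAGE_STEPS = 10_000   # optimizer steps per pretraining stage


def steps_to_ckpt(steps: int, grad_accum: int = GRAD_ACCUM) -> int:
    """Convert optimizer steps to the iteration count used as ckpt."""
    return steps * grad_accum


def stage_plan(num_stages: int, selector: str = "mates") -> list:
    """(ckpt, method) for each stage: random warmup first, then the given selector."""
    plan = []
    for stage in range(num_stages):
        ckpt = steps_to_ckpt(stage * STAGE_STEPS)
        method = "random" if stage == 0 else selector
        plan.append((ckpt, method))
    return plan


print(steps_to_ckpt(10_000))  # 40000
print(stage_plan(3))          # [(0, 'random'), (40000, 'mates'), (80000, 'mates')]
```

For the random baseline, pass selector="random" so that every stage uses random selection.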
After the first 10k steps, we can start the MATES data selection process every 10k steps. One data selection process consists of four steps:
1️⃣ Get oracle data influence:
model_name=pythia-410m \
method=random \
ckpt=40000 \
bash scripts/probe_oracle_data_influence.sh
- For the 10k-step checkpoint (ckpt=40000), use method=random; for later checkpoints, use method=mates.
2️⃣ Train data influence model:
model_name=pythia-410m \
ckpt=40000 \
bash scripts/train_data_influence_model.sh
3️⃣ Predict data influence:
model_name=pythia-410m \
ckpt=40000 \
bash scripts/predict_data_influence.sh
4️⃣ Select the training data for the next 10k steps:
python src/select_data/select_data.py --model_name pythia-410m --method mates --ckpt 40000
- The selected data will be saved in data/c4/pythia-410m/mates/40000.
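One selection round at a given checkpoint runs the four commands above in order. The sketch below only assembles them as shell strings, using hypothetical helper names, and executes nothing. It applies the rule that the probe step uses method=random at the 10k-step checkpoint (ckpt=40000) and method=mates afterwards.

```python
def probe_method(ckpt: int, first_selection_ckpt: int = 40_000) -> str:
    """Method passed to the oracle-probing step at a given checkpoint."""
    # The 10k-step checkpoint was trained on randomly selected data;
    # later checkpoints were trained on MATES-selected data.
    return "random" if ckpt == first_selection_ckpt else "mates"


def selection_round(model_name: str, ckpt: int) -> list:
    """Shell commands for one MATES selection round, in the README's order."""
    return [
        # 1) Probe oracle data influence on the current checkpoint.
        f"model_name={model_name} method={probe_method(ckpt)} ckpt={ckpt} "
        "bash scripts/probe_oracle_data_influence.sh",
        # 2) Train the data influence model.
        f"model_name={model_name} ckpt={ckpt} bash scripts/train_data_influence_model.sh",
        # 3) Predict data influence with the trained model.
        f"model_name={model_name} ckpt={ckpt} bash scripts/predict_data_influence.sh",
        # 4) Select the training data for the next 10k steps.
        f"python src/select_data/select_data.py --model_name {model_name} "
        f"--method mates --ckpt {ckpt}",
    ]


for cmd in selection_round("pythia-410m", 80_000):
    print(cmd)
```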
1️⃣ For better stability, it is advised to run a decay stage on intermediate checkpoints before evaluating them:
model_name=pythia-410m \
method=mates \
ckpt=80000 \
decay=true \
bash scripts/pretrain.sh
2️⃣ A simple evaluation example is given below; modify the parameters as needed.
model_name=pythia-410m \
method=mates \
ckpt=80800 \
bash scripts/eval.sh
- After running the evaluation script, you can find the results in results/c4/$model/$method/iter-$ckpt-ckpt/results.json.
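To inspect results programmatically, a minimal sketch might build that path and load the JSON file. The helper names are hypothetical, and the sketch deliberately makes no assumption about the keys inside results.json. It simply pretty-prints whatever the evaluation script wrote.

```python
import json
from pathlib import Path


def results_path(model: str, method: str, ckpt: int, root: str = "results/c4") -> Path:
    """Path pattern from the README: results/c4/$model/$method/iter-$ckpt-ckpt/results.json."""
    return Path(root) / model / method / f"iter-{ckpt}-ckpt" / "results.json"


def load_results(model: str, method: str, ckpt: int, root: str = "results/c4") -> dict:
    """Load the evaluation results; the structure of the file is not assumed."""
    with open(results_path(model, method, ckpt, root), encoding="utf-8") as f:
        return json.load(f)


if __name__ == "__main__":
    path = results_path("pythia-410m", "mates", 80800)
    if path.exists():
        print(json.dumps(load_results("pythia-410m", "mates", 80800), indent=2))
    else:
        print(f"No results yet at {path}")
```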
Please cite our paper if you use MATES in your work:
@inproceedings{yu2024mates,
title={MATES: Model-Aware Data Selection for Efficient Pretraining with Data Influence Models},
author={Yu, Zichun and Das, Spandan and Xiong, Chenyan},
booktitle={NeurIPS},
year={2024}
}