This is the official repository for the paper [Weak-to-Strong Reasoning](https://arxiv.org/abs/2407.13647).
We explore weak-to-strong learning for complex reasoning tasks, where a less capable model enhances the reasoning capabilities of a stronger one. Our progressive learning framework enables the strong model to autonomously refine its training data without requiring input from more advanced models or human-annotated data. The framework consists of two main stages:
- Supervised fine-tuning on a selective, small, but high-quality dataset;
- Preference optimization on contrastive samples identified by the strong model itself.
An overview of our framework is shown in the figure in our paper.
We split the training set into `train_1.jsonl` and `train_2.jsonl`:

- The weak model uses `train_1.jsonl` to develop initial reasoning skills;
- The strong model can only access the questions from `train_2.jsonl`, without ground truths.
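For reference, here is a minimal sketch of how the two splits might be consumed; the file paths and the `question`/`answer` field names are assumptions about the JSONL schema, not guaranteed by the repo:

```python
import json

def load_jsonl(path):
    """Read one JSON object per line."""
    with open(path) as f:
        return [json.loads(line) for line in f]

train_1 = load_jsonl("data/raw/gsm8k/train_1.jsonl")  # weak model: full supervision
train_2 = load_jsonl("data/raw/gsm8k/train_2.jsonl")  # strong model: questions only

# The strong model never sees the ground truths from train_2.
train_2_questions = [{"question": ex["question"]} for ex in train_2]
```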
The original data is available in `data/raw`. Note that we augment the GSM8K training set with the data constructed by Abel.
For experiments closer to future scenarios on OlympicArena, we only use `train_2` without ground truths. In the implementation, we use the `test` split as the `train_2` set and the `val` split as the `test` set. Please refer to the `load_data` code for more details.
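The split remapping described above amounts to roughly the following (a hedged sketch; the actual `load_data` implementation in this repo may differ):

```python
def load_olympic_splits(dataset):
    """OlympicArena: repurpose the released splits for weak-to-strong training.

    The original `test` split becomes our train_2 (questions without ground
    truths), and the original `val` split becomes our evaluation set.
    """
    train_2 = dataset["test"]  # used as unlabeled training questions
    test = dataset["val"]      # used for final evaluation
    return train_2, test
```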
| | train_1 | train_2 | test |
| --- | --- | --- | --- |
| GSM8K | 7,000 | 7,000 | 1,319 |
| MATH | 6,000 | 6,000 | 500 |
| OlympicArena | - | 6,020 | 313 |
- `outputs/`: Solutions generated by three weak models (llama2-7b, gemma-2b, and mistral-7b) for `train_2`.
- `data/test/`: Processed data for evaluation (created using `data/process.py`).
- `data/llama_factory/`: All data used for training (including stage I: supervised fine-tuning, and stage II: DPO or ORPO).
We have released the LoRA adapters that have undergone two-stage weak-to-strong training on the Hugging Face Hub.
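For example, the released adapters can be attached to the corresponding base model with `peft`; the model and adapter ids below are placeholders, so substitute the actual Hub ids:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_id = "meta-llama/Llama-2-70b-hf"        # placeholder base model id
adapter_id = "your-org/weak-to-strong-lora"  # placeholder adapter id

tokenizer = AutoTokenizer.from_pretrained(base_id)
model = AutoModelForCausalLM.from_pretrained(base_id, device_map="auto")
model = PeftModel.from_pretrained(model, adapter_id)  # attach the two-stage LoRA weights
model = model.merge_and_unload()  # optionally merge for faster inference (e.g., with vLLM)
```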
We provide inference code using vLLM in the `src/` directory.
- Install the required packages:

```bash
pip install transformers==4.38.1 torch==2.1.2 vllm==0.3.3
```

- Install any other necessary dependencies.
Refer to `run_gsm8k.sh`, `run_math.sh`, and `run_olympic.sh` for zero-shot, few-shot, or temperature-sampling inference. Few-shot templates can be found in `src/prompt/template.py`.
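A minimal zero-shot / temperature-sampling example with vLLM is sketched below; the model path, prompt, and sampling settings are illustrative, and the shell scripts above set the actual arguments:

```python
from vllm import LLM, SamplingParams

llm = LLM(model="path/to/merged-model")  # merged base + LoRA checkpoint

# Greedy decoding for zero-shot / few-shot evaluation.
greedy = SamplingParams(temperature=0.0, max_tokens=512)
# Temperature sampling for collecting multiple candidate solutions.
sample = SamplingParams(temperature=1.0, n=8, max_tokens=512)

prompts = ["Question: Natalia sold clips to 48 of her friends...\nAnswer:"]
outputs = llm.generate(prompts, greedy)
print(outputs[0].outputs[0].text)
```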
For the evaluation on OlympicArena, please refer to the OlympicArena repository.
To generate the actual training data (as provided in `data/llama_factory`), use `src/construct_training_data.py`:

- For stage I (supervised fine-tuning):
  - Run the `construct_weak_icl_data` function to find the intersection of weak data and ICL data where the final answers are consistent.
- For stage II (preference optimization):
  - Generate sampling data with temperature=1.0.
  - Use the `construct_paired_data_gsm8k` and `construct_paired_data_math` functions to create paired data (see the sketch after this list).
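Conceptually, the pairing step looks like the following. This is a simplified sketch, not the repo's exact logic: one plausible consistency criterion (majority voting over final answers) is shown, and `extract_final_answer` is a toy parser:

```python
import re
from collections import Counter

def extract_final_answer(solution: str) -> str:
    """Toy final-answer parser: take the last number in the solution text."""
    numbers = re.findall(r"-?\d+(?:\.\d+)?", solution)
    return numbers[-1] if numbers else ""

def construct_paired_data(samples):
    """Build one (chosen, rejected) pair from temperature-1.0 samples of a question.

    A sample whose final answer matches the majority answer is treated as
    chosen; a conflicting one as rejected. The repo's functions may apply a
    different consistency check.
    """
    answers = [extract_final_answer(s) for s in samples]
    majority, _ = Counter(answers).most_common(1)[0]

    chosen = [s for s, a in zip(samples, answers) if a == majority]
    rejected = [s for s, a in zip(samples, answers) if a != majority]
    if chosen and rejected:
        return chosen[0], rejected[0]
    return None  # no contrastive pair for this question
```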
For OlympicArena, please use `judge.py` from the OlympicArena repository to judge the consistency of two given responses, and modify the training-data construction code accordingly.
We employ LLaMA-Factory for all model training:
- v0.5.0 for supervised fine-tuning and DPO
- v0.6.2 for ORPO support
All training data and `dataset_info.json` are provided in `data/llama_factory`. For detailed training instructions, please refer to the LLaMA-Factory repository. Two arguments are worth noting: we use a new `vanilla` template without any specified format, and we set `lora_target` to `q_proj,v_proj,k_proj,o_proj,gate_proj,up_proj,down_proj`.
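To illustrate, the stage-II preference data is stored in LLaMA-Factory's pairwise (ranking) format, which in these versions looks roughly like the snippet below. Treat the field layout as an assumption; the `dataset_info.json` shipped in `data/llama_factory` is authoritative:

```python
import json

# One pairwise (DPO/ORPO) example in LLaMA-Factory's alpaca-style ranking
# format, where `output` holds [chosen, rejected]. This layout is an
# assumption; check dataset_info.json for the actual schema.
example = {
    "instruction": "Natalia sold clips to 48 of her friends in April...",
    "input": "",
    "output": [
        "A solution ending in the consistent final answer...",
        "A contrastive solution ending in a conflicting final answer...",
    ],
}

with open("paired_data.json", "w") as f:
    json.dump([example], f, indent=2)
```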
If you find our work useful, please cite our paper:
```bibtex
@misc{yang2024weaktostrongreasoning,
      title={Weak-to-Strong Reasoning},
      author={Yuqing Yang and Yan Ma and Pengfei Liu},
      year={2024},
      eprint={2407.13647},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2407.13647},
}
```