Commit c0c7882
feat: update the package for main model, create dcd evaluate, update readme
pphuc25 committed Feb 18, 2024
1 parent a87590e commit c0c7882
Showing 454 changed files with 649 additions and 7,722 deletions.
42 changes: 31 additions & 11 deletions README.md
@@ -6,69 +6,89 @@ by [Phuc Phan](https://www.linkedin.com/in/pphuc/), [Hieu Tran](https://www.link


## Introduction
In our paper, we introduce Distillation Contrastive Decoding (DCD), a novel approach designed to enhance the reasoning capabilities of large language models (LLMs) during inference. DCD leverages the power of contrastive chain-of-thought prompts (CP) and distillation to improve LLMs' task performance by minimizing reasoning errors.

DCD addresses the limitations of existing Contrastive Decoding techniques, such as the need for smaller parallel models and high memory usage, offering a more efficient and scalable solution. Our extensive evaluations across various reasoning benchmarks, including arithmetic and commonsense reasoning tasks, demonstrate the superior performance of DCD, marking a significant advancement in the field of natural language processing and LLM reasoning enhancement.
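Concretely, DCD scores each candidate next token by contrasting the expert model with an "amateur" version of the same model (e.g., with dropout applied). The sketch below illustrates the underlying Contrastive Decoding objective that DCD builds on — a rough illustration, not the exact DCD formulation (see the paper); the `alpha` plausibility threshold and `beta` penalty strength correspond to the `alpha_coef` and `beta_coef` arguments used later in this README:

```python
import torch

def contrastive_next_token_scores(expert_logits, amateur_logits, alpha=0.1, beta=0.8):
    # Log-probabilities of the next token under the expert and the amateur
    expert_logp = torch.log_softmax(expert_logits, dim=-1)
    amateur_logp = torch.log_softmax(amateur_logits, dim=-1)
    # Plausibility constraint: keep only tokens whose expert probability is
    # at least alpha times that of the most likely token
    cutoff = expert_logp.max(dim=-1, keepdim=True).values + torch.log(torch.tensor(alpha))
    # Reward tokens the expert favors, penalize tokens the amateur also favors
    scores = (1 + beta) * expert_logp - beta * amateur_logp
    return scores.masked_fill(expert_logp < cutoff, float("-inf"))
```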

<img align="center" src="assets/figure1-method.jpg" width="750">

<!-- <img align="center" src="assets/compare_methods.jpg" width="750"> -->


## Installation

To install the package, clone the repository and install the required dependencies:

```bash
git clone https://github.com/pphuc25/distillation-contrastive-decoding.git
cd distillation-contrastive-decoding
pip install -e .
```
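As a quick sanity check, the `dcd` module used in the Quickstart below should now be importable:

```bash
python -c "from dcd import dcd_pipeline_registry; print('dcd is installed')"
```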

## Quickstart

Our DCD approach is applicable for task generation in 🤗 Hugging Face with a beam size of 1. Here's a quick example of how to use it:

```python
# Import necessary libraries and functions
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from dcd import dcd_pipeline_registry, set_stop_words
from dcd import create_prompt, create_prompt_student

# Register 'DCD' into the greedy search of Hugging Face environment
dcd_pipeline_registry()

# Load the model and tokenizer
model = AutoModelForCausalLM.from_pretrained("model_name", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("model_name")  # tokenizers do not take device_map

# Set the generation configuration
generation_config = GenerationConfig(
num_beams=1,
pad_token_id=0,
eos_token_id=0,
)

# Set the alpha, beta, and dropout rate
alpha_coef = 0.1
beta_coef = 0.8
dropout_rate = 0.2

# Define the question and format it
# (Correct answer: Seattle 20, Charleston 4*20=80, Toulouse 2*80=160; 260 together)
question = "Toulouse has twice as many sheep as Charleston. Charleston has 4 times as many sheep as Seattle. How many sheep do Toulouse, Charleston, and Seattle have together if Seattle has 20 sheep?"
question_formated = "Q: " + question + "\n" + "A:"

# `args_prompt` holds the prompt configuration (including `data_name`) expected
# by create_prompt; see the examples directory for how to construct it
device = model.device
inputs = tokenizer(create_prompt(args_prompt, data_name=args_prompt.data_name) + question_formated, return_tensors="pt")
input_ids = inputs["input_ids"].to(device)

# Create input ids for the student model
# `type_prompt` selects the contrastive CoT prompt type for the amateur model
inputs_student = tokenizer(create_prompt_student(args_prompt, type=type_prompt, data_name=args_prompt.data_name) + question_formated, return_tensors="pt")
input_ids_student = inputs_student["input_ids"].to(device)

# Generate the output sequences
output_sequences = model.generate(
    generation_config=generation_config,
    input_ids=input_ids,
    # DCD-specific arguments
    input_ids_student=input_ids_student,
    teacher_student=True,
    dropout_rate=dropout_rate,
    alpha_coef=alpha_coef,
    beta_coef=beta_coef,
)

# Continue with your inference process...

```
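To inspect the model's answer, decode the generated sequence as with any Hugging Face model (a minimal sketch; slicing off the prompt tokens keeps only the newly generated text):

```python
generated_tokens = output_sequences[0][input_ids.shape[-1]:]
print(tokenizer.decode(generated_tokens, skip_special_tokens=True))
```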

## Examples

Check out the [examples](./examples) of Distillation Contrastive Decoding (DCD), covering both the plain dropout implementation and dropout combined with quantization. We welcome community contributions as well!


## DCD Evaluation

To facilitate easier evaluation or reproduction of DCD results, we have released an evaluation package, [DCD_eval](./dcd_eval), designed for few-shot evaluation of both arithmetic and commonsense reasoning on standard benchmarks. For more detailed information, please refer to our [paper]().


## Citation
If you find this useful in your research, please consider citing:
Binary file removed assets/compare_methods.png
Binary file not shown.
87 changes: 87 additions & 0 deletions dcd_eval/README.md
@@ -0,0 +1,87 @@
# Distillation Contrastive Decoding (DCD) Evaluation

## Overview

This package provides a framework for evaluating the performance of large language models (LLMs) on various standard benchmarks. For more information about the evaluation process, please refer to [our DCD paper]().


## Installation

```bash
# If you have already done this, you can skip these steps
git clone https://github.com/pphuc25/distillation-contrastive-decoding.git
cd distillation-contrastive-decoding
pip install -e .

# Setting up the evaluation environment
cd dcd_eval
bash install_packages.sh
```

## Basic Usage

To evaluate the generative performance of a language model on a specific dataset (GSM8K or StrategyQA), use the following command:

```bash
python3 src/run_generation.py \
    --model_name_or_path $model_name_or_path \
    --task $task \
    --ntrain $ntrain \
    --seed $seed

# Alternatively, you can use one of the provided bash configs
bash configs/combined/deepseak/quantize-strategy-deepseek-7b-base-beta08.sh
```

## Experiments

### Main Arguments

| Argument | Example | Description |
| ------------------------------- | -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
| `--model_name_or_path` | `meta-llama/Llama-2-7b-hf` | Specifies the model to be used. |
| `--student_name_or_path` | `TheBloke/Llama-2-7B-AWQ` | Specifies the student model to be used. In our context, it's the quantized model.|
| `--prompt_file` | `gsm8k` | The name of the dataset to be evaluated on the test set.|
| `--constractive_prompt_student` | `4` | The type of contrastive CoT prompting for the amateur model. The number corresponds to a prompt type described in the paper (see the appendix for details). |
| `--outfile` | `output_path.json` | The location to store the output results. |
| `--alpha_coef` | `1` | The threshold for plausibility. |
| `--beta_coef` | `27` | The strength of the amateur model compared to the expert model or the adjustment factor for the amateur penalty. |
| `--dropout_num` | `0.1` | The dropout rate of the amateur model. |

### Other Arguments

| Argument | Example | Description |
| ------------------ | ---------- | ----------------------------------------------------------------------------------------------------- |
| `--cot_flag` | `*enable*` | Add the flag text to extract the results. By default, the flag is "The answer is ". |
| `--fp16` | `*enable*` | The model will run in float16 (with quantization on the amateur model, this setting applies only to the expert model). |
| `--bf16` | `*enable*` | The model will run in bfloat16 (with quantization on the amateur model, this setting applies only to the expert model). |
| `--max_new_tokens` | `256` | The maximum number of tokens generated by the model. |
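Putting the arguments together, a full invocation might look like the following (an illustrative sketch: the output path is a placeholder, and the coefficient values mirror the Quickstart example rather than tuned settings):

```bash
python3 src/run_generation.py \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --student_name_or_path TheBloke/Llama-2-7B-AWQ \
    --prompt_file gsm8k \
    --constractive_prompt_student 4 \
    --outfile outputs/gsm8k_dcd_llama2-7b.json \
    --alpha_coef 0.1 \
    --beta_coef 0.8 \
    --dropout_num 0.2 \
    --max_new_tokens 256 \
    --cot_flag \
    --bf16
```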

### Understanding `--constractive_prompt_student`

The `--constractive_prompt_student` argument accepts an integer from 1 to 4, each corresponding to a type of contrastive prompting. By specifying different types, we can adjust the decoding behavior of the amateur model.


#### Arithmetic Task (GSM8K)

| Type | Description of Contrastive CoT Prompting |
| ---- | ----------------------------------------- |
| 1 | Rule-based Number Shuffle |
| 2 | Rule-based Number Shuffle with Calculation Wrong |
| 3 | Synthetic Demonstration |

#### Commonsense Task (StrategyQA)

| Type | Description of Contrastive CoT Prompting |
| ---- | ----------------------------------------- |
| 1 | Synthetic Demonstration |


## Citation

If you find this useful in your research, please consider citing:

```
```
8 changes: 0 additions & 8 deletions dcd_eval/configs/baseline/generate_500_gsm8k_llama13-2.sh

This file was deleted.

8 changes: 0 additions & 8 deletions dcd_eval/configs/baseline/generate_500_gsm8k_llama7-1.sh

This file was deleted.

8 changes: 0 additions & 8 deletions dcd_eval/configs/baseline/generate_500_gsm8k_llama7-2-chat.sh

This file was deleted.

8 changes: 0 additions & 8 deletions dcd_eval/configs/baseline/generate_500_gsm8k_llama7-2.sh

This file was deleted.

8 changes: 0 additions & 8 deletions dcd_eval/configs/baseline/generate_500_gsm8k_mistral-7b.sh

This file was deleted.

8 changes: 0 additions & 8 deletions dcd_eval/configs/baseline/generate_500_math_llama7-2.sh

This file was deleted.

7 changes: 0 additions & 7 deletions dcd_eval/configs/baseline/generate_500_strategy_llama13-2.sh

This file was deleted.

8 changes: 0 additions & 8 deletions dcd_eval/configs/baseline/generate_500_strategy_llama7-2.sh

This file was deleted.

@@ -2,7 +2,7 @@
 python3 src/run_generation.py \
     --model_name_or_path meta-llama/Llama-2-13b-hf \
     --max_new_tokens 256 \
-    --prompt_file gsm8k_500 \
+    --prompt_file gsm8k \
     --student_name_or_path TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T \
     --outfile outputs/gsm8k_full_llama2-13b_stu1,1b.json \
     --cot_flag
@@ -2,7 +2,7 @@
 python3 src/run_generation.py \
     --model_name_or_path meta-llama/Llama-2-7b-hf \
     --max_new_tokens 256 \
-    --prompt_file gsm8k_500 \
+    --prompt_file gsm8k \
     --student_name_or_path TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T \
     --outfile outputs/gsm8k_full_llama2-7b_stu1,1b.json \
     --cot_flag

