Merge pull request #1 from pan-x-c/feature/ee_tune

Release EE-Tuning

pan-x-c authored Jan 30, 2024
2 parents 57164d8 + b696a9f commit f2fd105

Showing 29 changed files with 1,252 additions and 131 deletions.
103 changes: 90 additions & 13 deletions README.md
@@ -1,9 +1,23 @@
-# README
+# EE-LLM: Early-Exit Large Language Models

-[EE-LLM](https://arxiv.org/abs/2312.04916) is a framework for large-scale training and inference of early-exit (EE) large language models (LLMs), which is built upon [Megatron-LM](https://github.com/NVIDIA/Megatron-LM) and currently under active development.
+[EE-LLM](https://arxiv.org/abs/2312.04916) is a framework for large-scale training and inference of early-exit (EE) large language models (LLMs), which is built upon [Megatron-LM](https://github.com/NVIDIA/Megatron-LM) and compatible with 3D parallelism (namely data, tensor, sequence and pipeline parallelism).

![](images/ee_architecture.png)


As shown in the figure above, an early-exit LLM can convert intermediate hidden states into outputs.
During inference, the model can adaptively select an early or final exit to generate the output for each input, without running the full-model forward pass.
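
As a rough illustration of this exit rule, here is a conceptual sketch (not the EE-LLM implementation; KV caching and batching are omitted, and all names are illustrative):

```python
import torch

def generate_next_token(x, layers, exit_heads, final_head, threshold=0.8):
    """Conceptual early-exit decoding for a single token position.

    layers:     list of Transformer blocks, each a callable on the hidden state
    exit_heads: dict {layer index: head mapping the hidden state to vocab logits}
    final_head: output head of the final exit, used when no early exit fires
    """
    h = x
    for i, layer in enumerate(layers):
        h = layer(h)
        if i in exit_heads:
            probs = torch.softmax(exit_heads[i](h), dim=-1)
            confidence, token = probs.max(dim=-1)
            if confidence.item() >= threshold:
                return token  # confident enough: skip the remaining layers
    return final_head(h).argmax(dim=-1)  # otherwise use the final exit
```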

Our system supports two methods of training early-exit LLMs:

- Full-parameter training, which updates model parameters by optimizing a weighted sum of losses from multiple exits (see the sketch below this list);
- EE-Tuning, a parameter-efficient approach that augments an existing pre-trained LLM with early-exit layers and tunes them while modules of the original LLM are frozen.
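
A minimal sketch of the weighted-sum objective used in full-parameter training (illustrative only; the actual loss weights and scaling are set through EE-LLM's own training options):

```python
import torch.nn.functional as F

def early_exit_loss(exit_logits, final_logits, labels, exit_weights):
    """Weighted sum of language-modeling losses over all exits.

    exit_logits:  list of [batch, seq, vocab] logit tensors, one per early exit
    final_logits: [batch, seq, vocab] logits of the final (original) exit
    labels:       [batch, seq] target token ids
    exit_weights: list of scalar weights, one per early exit
    """
    vocab = final_logits.size(-1)
    loss = F.cross_entropy(final_logits.reshape(-1, vocab), labels.reshape(-1))
    for weight, logits in zip(exit_weights, exit_logits):
        loss = loss + weight * F.cross_entropy(logits.reshape(-1, vocab), labels.reshape(-1))
    return loss
```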

Further details about the usage and functionality of EE-LLM are introduced below.



## Installation

The installation of EE-LLM is the same as that of Megatron-LM.
@@ -12,23 +26,23 @@ We recommend using the 22.12 version of [NGC's PyTorch container](https://catalo
For more details about the installation of Megatron-LM, please refer to Megatron-LM's [README](README_Megatron_LM.md).


-## Training
+## Full-parameter training

Below are several example training scripts used in our paper.


-```
+```shell
# train 1.3B model
-./examples/early_exit/1-3B.sh
+./examples/ee_training/1-3B.sh

# train 7B model
-./examples/early_exit/7B.sh
+./examples/ee_training/7B.sh

# train 13B model
-./example/early_exit/13B.sh
+./examples/ee_training/13B.sh

# train 30B model
-./example/early_exit/30B.sh
+./examples/ee_training/30B.sh
```


@@ -40,7 +54,7 @@ for more details about Megatron-LM's data preprocessing, please refer to [Data P
> Running the training scripts requires 16 Nvidia A100-80G GPUs or higher hardware specifications. To run them with fewer GPUs, please set the parallelism degrees therein to smaller values.

-Below are the new configurations of EE-LLM compared to Megatron-LM. You can customize your own early-exit LLM by modifying these configurations.
+Below are some new configurations of EE-LLM compared to Megatron-LM. You can customize your own early-exit LLM by modifying these configurations.

### Configurations for model architectures

@@ -76,14 +90,72 @@ Below are the new configurations of EE-LLM compared to Megatron-LM. You can cust

- `--backward-forward-ratio`: An estimate of the ratio of time consumption between backward and forward computation during training, used to automatically calculate the optimal number of inserted microbatches. Defaults to 2.0. [Experimental]


## EE-Tuning


> Before using EE-Tuning, please make sure that the existing LLM checkpoint is in Megatron-LM format.
> As an example, `examples/ee_tuning/convert/convert_llama_hf.sh` converts a Llama 2 HuggingFace checkpoint into Megatron-LM format.

### Stage 1: initialize early-exit layers

The first step of EE-Tuning is to use `tools/checkpoint/checkpoint_converter.py` to add early-exit layers to the standard LLM checkpoint.
Example scripts can be found in the following file:

```shell
examples/ee_tuning/convert/add_exit_layers.sh
```

The relevant arguments are listed below:

- `--load-dir`: Path to the standard LLM checkpoint in Megatron-LM format.

- `--load-iteration`: The iteration number of the checkpoint to be loaded.

- `--save-dir`: Path to the output early-exit LLM checkpoint.

- `--add-exit-layer-nums`: Indices of the backbone Transformer layers that early exits are added to.

- `--use-exit-norm`: Add layer normalization (LayerNorm/RMSNorm) to the early-exit layer.

- `--use-exit-mlp`: Add an MLP to the early-exit layer.

- `--use-exit-block`: Add a Transformer layer to the early-exit layer.

- `--random-init`: Initialize the model parameters of early-exit layers randomly. Otherwise, they are initialized as duplicates of certain modules of the original LLM (see the sketch below this list).

- `--megatron-path`: Path to EE-LLM root directory.
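
The effect of `--random-init` versus the default duplication can be pictured with a small PyTorch sketch; the module names and sizes below are illustrative, not the converter's internals:

```python
import torch.nn as nn

hidden_size, vocab_size = 4096, 32000

# Stand-ins for modules of the original pre-trained LLM.
final_norm = nn.LayerNorm(hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

# New early-exit modules produced by the converter.
exit_norm = nn.LayerNorm(hidden_size)                       # --use-exit-norm
exit_head = nn.Linear(hidden_size, vocab_size, bias=False)

random_init = False  # corresponds to --random-init
if random_init:
    nn.init.normal_(exit_head.weight, mean=0.0, std=0.02)
else:
    # Default behaviour: duplicate the corresponding modules of the original LLM.
    exit_norm.load_state_dict(final_norm.state_dict())
    exit_head.load_state_dict(lm_head.state_dict())
```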


### Stage 2: tune early-exit layers

The second step of EE-Tuning is to tune the early-exit layers of the converted checkpoint, using scripts similar to those for [full-parameter training](#full-parameter-training). Below are some example scripts.

```shell
# tune Llama 2-Chat 13B with 8 exits
./examples/ee_tuning/tune/llama2_13B_8_exit_mlp_pt.sh

# tune Llama 2-Chat 13B with 1 exit (only load the first 1/4 of the model)
./examples/ee_tuning/tune/llama2_13B_1_exit_mlp_pt.sh
```

Below are the new parameters relevant to EE-Tuning; a conceptual sketch of the freeze-and-tune setup follows this list. Other parameters are the same as those for full-parameter training.

- `--tune-exit`: Activate the functionality of EE-Tuning.

- `--tune-exit-pipeline-parallel-size`: Used to support partial checkpoint loading; only pipeline stages whose stage numbers are not larger than this value are loaded.
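
Conceptually, EE-Tuning freezes the backbone and optimizes only the early-exit modules. The sketch below illustrates that idea only; it is not EE-LLM's training loop, and the `"exit"` name filter is an assumption made for illustration:

```python
import torch

def build_exit_only_optimizer(model, lr=1e-4):
    """Freeze backbone parameters and optimize only early-exit modules."""
    exit_params = []
    for name, param in model.named_parameters():
        if "exit" in name:               # illustrative way to spot exit modules
            param.requires_grad = True
            exit_params.append(param)
        else:
            param.requires_grad = False  # backbone stays frozen
    return torch.optim.AdamW(exit_params, lr=lr)
```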



## Inference

-We provided an text generation server for inference of early-exit LLMs.
+We provide a text generation server for inference of early-exit LLMs.
To start a server, you can use the following script.
-Before running, please set `CHECKPOINT_PATH` to the root folder path of the checkpoint, and set `TP` and `PP` appropriately according to the parallelism of the checkpoint.
+Before running, please set `CHECKPOINT_PATH` to the root folder path of the checkpoint, and set `TP` and `PP` appropriately according to the parallelism degrees of the checkpoint.

-```
-./example/early_exit/ee_inference_server.sh
+```shell
+./examples/ee_inference/ee_inference_server.sh
```

After the server is started, you can use `tools/request_client.py` to send requests to the server.
@@ -93,8 +165,13 @@ Below are some parameters for early-exit LLM inference, which can be found in `t

- `early_exit_thres`: The confidence threshold used to determine whether to execute early exiting, ranging from 0.0 to 1.0.

- `exit_layers`: Only the early-exit layers listed here will be activated. If empty, all available early-exit layers will be activated.

- `print_max_prob`: If set, the inference server will print the token with the highest confidence and the confidence values at all exits.
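
For reference, a request along these lines is roughly what `tools/request_client.py` sends; the endpoint path, port, and exact JSON field layout below are assumptions, so please check the client script for the actual format:

```python
import requests

# Assumed server address and endpoint; adjust to your deployment and to the
# actual request format used by tools/request_client.py.
URL = "http://localhost:5000/api"

payload = {
    "prompts": ["Early-exit inference lets a language model"],
    "tokens_to_generate": 64,
    "early_exit_thres": 0.8,   # confidence threshold in [0.0, 1.0]
    "exit_layers": [],         # empty list: all available early exits are active
    "print_max_prob": True,    # ask the server to print per-exit confidences
}

response = requests.put(URL, json=payload)
print(response.json())
```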

## Checkpoints

The model checkpoints mentioned in our paper will be released soon.

## BibTeX

@@ -37,6 +37,10 @@ SERVER_ARGS="
--port $PORT
"

CUR_DIR=$(cd $(dirname "$0") && pwd)
MEGATRON_ROOT_PATH=$(cd "$CUR_DIR/../.." && pwd)
cd $MEGATRON_ROOT_PATH

torchrun $DIST_ARGS \
tools/run_early_exit_text_generation_server.py \
$SERVER_ARGS
6 changes: 5 additions & 1 deletion examples/early_exit/1-3B.sh → examples/ee_training/1-3B.sh
@@ -139,12 +139,16 @@ OUTPUT_ARGS="
--log-timers-to-tracker \
--save-interval $SAVE_INTERVAL \
--eval-interval $EVAL_INTERVAL \
-    --eval-iters 0 \
+    --eval-iters 10 \
--wandb-project $PROJECT_NAME \
--wandb-group $GROUP_NAME \
--wandb-exp-name $RUN_NAME \
"

CUR_DIR=$(cd $(dirname "$0") && pwd)
MEGATRON_ROOT_PATH=$(cd "$CUR_DIR/../.." && pwd)
cd $MEGATRON_ROOT_PATH

torchrun $DIST_ARGS \
pretrain_early_exit_gpt.py \
$GPT_ARGS \
8 changes: 6 additions & 2 deletions examples/early_exit/13B.sh → examples/ee_training/13B.sh
@@ -1,7 +1,7 @@
#!/bin/bash

PROJECT_NAME=EE-LLM
-GROUP_NAME=7B-EXIT-8-16-untie-300B
+GROUP_NAME=13B-EXIT-10-20-untie-800B

RUN_NAME=`date "+%m%d-%H%M"`

@@ -141,12 +141,16 @@ OUTPUT_ARGS="
--log-timers-to-tracker \
--save-interval $SAVE_INTERVAL \
--eval-interval $EVAL_INTERVAL \
-    --eval-iters 0 \
+    --eval-iters 10 \
--wandb-project $PROJECT_NAME \
--wandb-group $GROUP_NAME \
--wandb-exp-name $RUN_NAME \
"

CUR_DIR=$(cd $(dirname "$0") && pwd)
MEGATRON_ROOT_PATH=$(cd "$CUR_DIR/../.." && pwd)
cd $MEGATRON_ROOT_PATH

torchrun $DIST_ARGS \
pretrain_early_exit_gpt.py \
$GPT_ARGS \
8 changes: 6 additions & 2 deletions examples/early_exit/30B.sh → examples/ee_training/30B.sh
@@ -1,7 +1,7 @@
#!/bin/bash

PROJECT_NAME=EE-LLM
-GROUP_NAME=7B-EXIT-8-16-untie-300B
+GROUP_NAME=30B-EXIT-15-30-untie-800B

RUN_NAME=`date "+%m%d-%H%M"`

@@ -141,12 +141,16 @@ OUTPUT_ARGS="
--log-timers-to-tracker \
--save-interval $SAVE_INTERVAL \
--eval-interval $EVAL_INTERVAL \
-    --eval-iters 0 \
+    --eval-iters 10 \
--wandb-project $PROJECT_NAME \
--wandb-group $GROUP_NAME \
--wandb-exp-name $RUN_NAME \
"

CUR_DIR=$(cd $(dirname "$0") && pwd)
MEGATRON_ROOT_PATH=$(cd "$CUR_DIR/../.." && pwd)
cd $MEGATRON_ROOT_PATH

torchrun $DIST_ARGS \
pretrain_early_exit_gpt.py \
$GPT_ARGS \
8 changes: 6 additions & 2 deletions examples/early_exit/7B.sh → examples/ee_training/7B.sh
@@ -1,7 +1,7 @@
#!/bin/bash

PROJECT_NAME=EE-LLM
-GROUP_NAME=7B-EXIT-8-16-untie-300B
+GROUP_NAME=7B-EXIT-8-16-untie-800B

RUN_NAME=`date "+%m%d-%H%M"`

@@ -141,12 +141,16 @@ OUTPUT_ARGS="
--log-timers-to-tracker \
--save-interval $SAVE_INTERVAL \
--eval-interval $EVAL_INTERVAL \
-    --eval-iters 0 \
+    --eval-iters 10 \
--wandb-project $PROJECT_NAME \
--wandb-group $GROUP_NAME \
--wandb-exp-name $RUN_NAME \
"

CUR_DIR=$(cd $(dirname "$0") && pwd)
MEGATRON_ROOT_PATH=$(cd "$CUR_DIR/../.." && pwd)
cd $MEGATRON_ROOT_PATH

torchrun $DIST_ARGS \
pretrain_early_exit_gpt.py \
$GPT_ARGS \
99 changes: 99 additions & 0 deletions examples/ee_tuning/convert/add_exit_layers.sh
@@ -0,0 +1,99 @@
#!/bin/bash

LOAD_DIR= # path to the converted llama checkpoint in megatron format
SAVE_DIR= # path to save the converted EE LLM checkpoint

LOAD_ITER=1
CUR_DIR=$(cd $(dirname "$0") && pwd)
MEGATRON_ROOT_PATH=$(cd "$CUR_DIR/../../.." && pwd)

# For llama2 13B model (40 layers)

## add an embedding-only exit every 1/8 depth
# python ${MEGATRON_ROOT_PATH}/tools/checkpoint/checkpoint_converter.py \
# --load-dir $LOAD_DIR \
# --save-dir $SAVE_DIR \
# --load-iteration $LOAD_ITER \
# --conversion-type add-exit \
# --add-exit-layer-nums 5 10 15 20 25 30 35 40 \
# --megatron-path $MEGATRON_ROOT_PATH

## add an embedding-norm exit every 1/8 depth
# python ${MEGATRON_ROOT_PATH}/tools/checkpoint/checkpoint_converter.py \
# --load-dir $LOAD_DIR \
# --save-dir $SAVE_DIR \
# --load-iteration $LOAD_ITER \
# --use-exit-norm \
# --conversion-type add-exit \
# --add-exit-layer-nums 5 10 15 20 25 30 35 40 \
# --megatron-path $MEGATRON_ROOT_PATH

## add an embedding-norm-mlp exit every 1/8 depth
python ${MEGATRON_ROOT_PATH}/tools/checkpoint/checkpoint_converter.py \
--load-dir $LOAD_DIR \
--save-dir $SAVE_DIR \
--load-iteration $LOAD_ITER \
--use-exit-norm \
--use-exit-mlp \
--conversion-type add-exit \
--add-exit-layer-nums 5 10 15 20 25 30 35 40 \
--megatron-path $MEGATRON_ROOT_PATH

## add an embedding-norm-layer exit every 1/8 depth
# python ${MEGATRON_ROOT_PATH}/tools/checkpoint/checkpoint_converter.py \
# --load-dir $LOAD_DIR \
# --save-dir $SAVE_DIR \
# --load-iteration $LOAD_ITER \
# --use-exit-norm \
# --use-exit-block \
# --conversion-type add-exit \
# --add-exit-layer-nums 5 10 15 20 25 30 35 40 \
# --megatron-path $MEGATRON_ROOT_PATH

## add an embedding-norm-mlp exit at 1/4 depth
# python ${MEGATRON_ROOT_PATH}/tools/checkpoint/checkpoint_converter.py \
# --load-dir $LOAD_DIR \
# --save-dir $SAVE_DIR \
# --load-iteration $LOAD_ITER \
# --use-exit-norm \
# --use-exit-mlp \
# --conversion-type add-exit \
# --add-exit-layer-nums 10 \
# --megatron-path $MEGATRON_ROOT_PATH

## add a randomly initialized embedding-norm-mlp exit at 1/4 depth
# python ${MEGATRON_ROOT_PATH}/tools/checkpoint/checkpoint_converter.py \
# --load-dir $LOAD_DIR \
# --save-dir $SAVE_DIR \
# --load-iteration $LOAD_ITER \
# --use-exit-norm \
# --use-exit-mlp \
# --random-init \
# --conversion-type add-exit \
# --add-exit-layer-nums 10 \
# --megatron-path $MEGATRON_ROOT_PATH

# For llama2 70B model (80 layers)

## add an embedding-norm-mlp exit every 1/8 depth
# python ${MEGATRON_ROOT_PATH}/tools/checkpoint/checkpoint_converter.py \
# --load-dir $LOAD_DIR \
# --save-dir $SAVE_DIR \
# --load-iteration $LOAD_ITER \
# --use-exit-norm \
# --use-exit-mlp \
# --conversion-type add-exit \
# --add-exit-layer-nums 10 20 30 40 50 60 70 80 \
# --megatron-path $MEGATRON_ROOT_PATH

# For llama2 7B model (32 layers)

## add an embedding-norm-mlp exit every 1/8 depth
# python ${MEGATRON_ROOT_PATH}/tools/checkpoint/checkpoint_converter.py \
# --load-dir $LOAD_DIR \
# --save-dir $SAVE_DIR \
# --load-iteration $LOAD_ITER \
# --use-exit-norm \
# --use-exit-mlp \
# --conversion-type add-exit \
# --add-exit-layer-nums 4 8 12 16 20 24 28 32 \
# --megatron-path $MEGATRON_ROOT_PATH
22 changes: 22 additions & 0 deletions examples/ee_tuning/convert/convert_llama_hf.sh
@@ -0,0 +1,22 @@
#!/bin/bash

LOAD_DIR= # path to the llama2 huggingface checkpoint dir
SAVE_DIR= # path to save the converted megatron checkpoint
TP=1 # target tensor parallel size
PP=4 # target pipeline parallel size

TOKENIZER_PATH=${LOAD_DIR}/tokenizer.model

CUR_DIR=$(cd $(dirname "$0") && pwd)
MEGATRON_ROOT_PATH=$(cd "$CUR_DIR/../../.." && pwd)

python $MEGATRON_ROOT_PATH/tools/checkpoint/util.py \
--model-type EarlyExitGPT \
--load-dir $LOAD_DIR \
--save-dir $SAVE_DIR \
--loader llama2_hf \
--saver megatron \
--target-tensor-parallel-size $TP \
--target-pipeline-parallel-size $PP \
--megatron-path $MEGATRON_ROOT_PATH \
--tokenizer-model $TOKENIZER_PATH