Patch FlashAttention2 for Llama #3

Open · wants to merge 5 commits into base: main
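In brief, this patch replaces the opt-out --no_fa switch in examples/run.sh with an opt-in --use_flash_attn switch (see the note after the run.sh diff below), updates the Llama example scripts to use it, passes config to ta.accelerate as a keyword argument in the accelerator classes, and disables use_cache on the Llama model config unconditionally instead of only when pipeline parallelism is enabled.

Example invocation with the new flag, taken verbatim from the updated examples/llama_acc.sh (the model path and parallel sizes are just that example's defaults):

./examples/run.sh --model ./hf_models/config/llama-1b --accelerator acc --gc --mbs 4 --fsdp 4 --use_flash_attn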
2 changes: 1 addition & 1 deletion examples/llama3_acc.sh
@@ -3,4 +3,4 @@ set -ex

# FSDP
# note: this need transformers>=4.41.0
-./examples/run.sh --model ./hf_models/config/llama-3-1b --accelerator acc --gc --mbs 2 --fsdp 8 --max_seq_length 4096 --no_fa
+./examples/run.sh --model ./hf_models/config/llama-3-1b --accelerator acc --gc --mbs 2 --fsdp 8 --max_seq_length 4096 --use_flash_attn
2 changes: 1 addition & 1 deletion examples/llama_acc.sh
@@ -2,7 +2,7 @@
set -ex

# FSDP
-./examples/run.sh --model ./hf_models/config/llama-1b --accelerator acc --gc --mbs 4 --fsdp 4
+./examples/run.sh --model ./hf_models/config/llama-1b --accelerator acc --gc --mbs 4 --fsdp 4 --use_flash_attn

# TP
# ./examples/run.sh --model ./hf_models/config/llama-1b --accelerator acc --gc --mbs 24 --tp 4
18 changes: 9 additions & 9 deletions examples/run.sh
@@ -18,17 +18,17 @@ DP_NUM=1 # data parallelism number
PP_NUM=1 # pipeline parallelism number
TP_NUM=1 # tensor parallelism number
FSDP_NUM=1 # fsdp number
-FLASH_ATTN=1 # enable flash-attn-2
DATA=./data/wikitext-2-raw-v1.json # data name or path
MODEL_NAME_OR_PATH="./hf_models/config/llama-1b" # model name or path
+USE_FLASH_ATTN=1


OTHER_ARGS=""

HELP_STR=("Usage: bash examples/run.sh [-h|--help] [--accelerator {acc, cuda}] [--model MODEL_NAME_OR_PATH] \n"
"\t[--data DATASET_NAME_OR_PATH] [--mbs MICRO_BATCH_SIZE] [--max_seq_length MAX_SEQ_LENGTH] \n"
"\t[--num_train_epochs NUM_TRAIN_EPOCHS] [--max_steps MAX_TRAIN_STEPS] [--pp PP_NUM] [--tp TP_NUM] [--fsdp FSDP_NUM] \n"
"\t[--ga GRADIENT_ACCUMULATION_STEPS] [--gc] [--bf16] [--fp16] [--fp32] [--no_fa] [--log_interval LOG_INTERVAL] \n"
"\t[--ga GRADIENT_ACCUMULATION_STEPS] [--gc] [--bf16] [--fp16] [--fp32] [--use_flash_attn] [--log_interval LOG_INTERVAL] \n"
"\t[other args for apps/train.py] \n"
"Examples: \n"
"\tbash examples/run.sh --accelerator cuda --model ./hf_models/config/llama-7b\n"
@@ -125,8 +125,8 @@ while [[ $# -gt 0 ]]; do
BF16=0
shift
;;
-    --no_fa)
-        FLASH_ATTN=0
+    --use_flash_attn)
+        ACC_FLASH_ATTN=1
shift
;;
--log_interval)
@@ -150,6 +150,11 @@ OPTION_ARGS=""
[[ "$BF16" -eq 1 ]] && OPTION_ARGS+="--bf16 "
[[ "$FP16" -eq 1 ]] && OPTION_ARGS+="--fp16 "

if [[ "$ACC_FLASH_ATTN" == 1 && ( "$FP16" -eq 1 || "$BF16" -eq 1 ) ]]; then
OPTION_ARGS+="--use_flash_attn "
export ACC_FLASH_ATTN=1
fi

if [ "$ACCELERATOR" == "cuda" ]; then
[ "$PP_NUM" -gt 1 ] && echo "Error: Pipeline Parallelism is not supported for cuda accelerator." && exit 1
[ "$TP_NUM" -gt 1 ] && echo "Error: Tensor Parallelism is not supported for cuda accelerator." && exit 1
@@ -160,11 +165,6 @@ if [ "$TP_NUM" -gt "1" ]; then
export XLA_USE_SPMD=1
fi


if [[ "$ACCELERATOR" == "acc" && "FLASH_ATTN" -eq 1 && ( "$FP16" -eq 1 || "$BF16" -eq 1 ) ]]; then
export ACC_FLASH_ATTN=1
fi

export XLA_PERSISTENT_CACHE_PATH=./compiled_cache/

MODEL_NAME=$(basename $MODEL_NAME_OR_PATH)
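A note on the run.sh hunks above: flash attention was previously on by default (FLASH_ATTN=1, with --no_fa to turn it off) and the old export additionally required --accelerator acc; it is now off unless --use_flash_attn is passed, and the flag only takes effect together with --fp16 or --bf16, presumably because the FlashAttention-2 kernels require half-precision inputs. A condensed sketch of the new flow, collapsing the two hunks (the comments are explanatory and not part of the script):

# argument parsing: record that flash attention was requested
--use_flash_attn)
    ACC_FLASH_ATTN=1
    shift
    ;;

# after parsing: honor the request only under mixed precision
if [[ "$ACC_FLASH_ATTN" == 1 && ( "$FP16" -eq 1 || "$BF16" -eq 1 ) ]]; then
    OPTION_ARGS+="--use_flash_attn "
    export ACC_FLASH_ATTN=1
fi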
2 changes: 1 addition & 1 deletion flashmodels/accelerators/acc_baichuan_accelerator.py
@@ -17,7 +17,7 @@ def accelerate_internal(self, model, loader):
raise NotImplementedError("resume_from_checkpoint.")

config = self.get_config(model)
-model = ta.accelerate(model, config)
+model = ta.accelerate(model, config=config)
return model, loader

def get_config(self, model):
2 changes: 1 addition & 1 deletion flashmodels/accelerators/acc_gemma_accelerator.py
@@ -12,7 +12,7 @@ def accelerate(self, model, loader):

def accelerate_internal(self, model, loader):
config = self.get_config()
-model = ta.accelerate(model, config)
+model = ta.accelerate(model, config=config)
return model, loader

def get_config(self):
2 changes: 1 addition & 1 deletion flashmodels/accelerators/acc_glm_accelerator.py
@@ -17,7 +17,7 @@ def accelerate_internal(self, model, loader):
raise NotImplementedError("resume_from_checkpoint.")

config = self.get_config(model)
-model = ta.accelerate(model, config)
+model = ta.accelerate(model, config=config)
return model, loader

def get_config(self, model):
2 changes: 1 addition & 1 deletion flashmodels/accelerators/acc_gpt_accelerator.py
@@ -20,7 +20,7 @@ def accelerate_internal(self, model, loader):
raise NotImplementedError("resume_from_checkpoint.")

config = self.get_config(model)
-model = ta.accelerate(model, config)
+model = ta.accelerate(model, config=config)
return model, loader

device = lazy_device()
6 changes: 2 additions & 4 deletions flashmodels/accelerators/acc_llama_accelerator.py
@@ -85,9 +85,7 @@ def accelerate_internal(self, model, loader):
model = self.tensor_parallel(model)
return model, loader

-if self.args.pp_num > 1:
-    # Prevent unnecessary model outputs
-    model.model.config.use_cache = False
+model.model.config.use_cache = False
# TODO: support this in torchacc
if self.args.resume_from_checkpoint:
assert self.args.fsdp_num == self.args.world_size, \
@@ -101,7 +99,7 @@ def accelerate_internal(self, model, loader):
self.args.sp)

config = self.get_config(model)
-model = ta.accelerate(model, config)
+model = ta.accelerate(model, config=config)

if self.args.tp_num > 1 and self.args.pp_num > 1:
self.parallel_3d(model._get_underlay_model())
2 changes: 1 addition & 1 deletion flashmodels/accelerators/acc_olmo_accelerator.py
@@ -17,7 +17,7 @@ def accelerate_internal(self, model, loader):
raise NotImplementedError("resume_from_checkpoint.")

config = self.get_config(model)
-model = ta.accelerate(model, config)
+model = ta.accelerate(model, config=config)
return model, loader
else:
raise NotImplementedError("Currently, only FSDP is supported.")
2 changes: 1 addition & 1 deletion flashmodels/accelerators/acc_qwen_accelerator.py
@@ -37,7 +37,7 @@ def accelerate_internal(self, model, loader):
raise NotImplementedError("resume_from_checkpoint.")

config = self.get_config(model)
-model = ta.accelerate(model, config)
+model = ta.accelerate(model, config=config)
return model, loader

def get_config(self, model):