From ab6bfb4874518597a112b55e21a6241989ac7f5d Mon Sep 17 00:00:00 2001 From: qianhao Date: Tue, 9 Jul 2024 15:00:52 +0800 Subject: [PATCH 01/31] merge easycontext --- Llama3-70B.sh | 169 ++++ Llama3-8B.sh | 172 ++++ data/dataset_info.json | 8 +- examples/accelerate/ds_multi_nodes.yaml | 15 + examples/deepspeed/ds_z3_offload_config.json | 12 +- llama3_full_sft_ds3.yaml | 40 + src/llamafactory/data/collator.py | 30 +- src/llamafactory/easy_context/__init__.py | 88 ++ .../easy_context/dist_flash_attn/README.md | 11 + .../dist_flash_attn/async_communication.py | 527 ++++++++++++ .../dist_flash_attn/lightseq_async_attn.py | 743 +++++++++++++++++ .../lightseq_async_attn_varlen.py | 772 ++++++++++++++++++ .../dist_flash_attn/monkey_patch.py | 609 ++++++++++++++ .../dist_flash_attn/prepare_input.py | 72 ++ .../easy_context/ulysses_attn/monkey_patch.py | 107 +++ .../ulysses_attn/prepare_inputs.py | 80 ++ .../monkey_patch.py | 94 +++ .../zigzag_ring_attn/monkey_patch.py | 113 +++ .../zigzag_ring_attn/prepare_inputs.py | 76 ++ src/llamafactory/hparams/finetuning_args.py | 4 + src/llamafactory/train/sft/trainer.py | 151 +++- src/llamafactory/train/sft/workflow.py | 22 +- 22 files changed, 3899 insertions(+), 16 deletions(-) create mode 100644 Llama3-70B.sh create mode 100644 Llama3-8B.sh create mode 100644 examples/accelerate/ds_multi_nodes.yaml create mode 100644 llama3_full_sft_ds3.yaml create mode 100644 src/llamafactory/easy_context/__init__.py create mode 100644 src/llamafactory/easy_context/dist_flash_attn/README.md create mode 100644 src/llamafactory/easy_context/dist_flash_attn/async_communication.py create mode 100644 src/llamafactory/easy_context/dist_flash_attn/lightseq_async_attn.py create mode 100644 src/llamafactory/easy_context/dist_flash_attn/lightseq_async_attn_varlen.py create mode 100644 src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py create mode 100644 src/llamafactory/easy_context/dist_flash_attn/prepare_input.py create mode 100644 src/llamafactory/easy_context/ulysses_attn/monkey_patch.py create mode 100644 src/llamafactory/easy_context/ulysses_attn/prepare_inputs.py create mode 100644 src/llamafactory/easy_context/unsloth_offloaded_gradient_checkpoint/monkey_patch.py create mode 100644 src/llamafactory/easy_context/zigzag_ring_attn/monkey_patch.py create mode 100644 src/llamafactory/easy_context/zigzag_ring_attn/prepare_inputs.py diff --git a/Llama3-70B.sh b/Llama3-70B.sh new file mode 100644 index 0000000000..1101716720 --- /dev/null +++ b/Llama3-70B.sh @@ -0,0 +1,169 @@ +# You can observe that the number of steps for different stage is quite different. They are not magic number. They are set to those numbers simply because I esitimate the time it takes to finish the training, and +# choose the number such that it fits my daily schedule>_<. This is for you to exactly reproduce my results. You many change the steps to other numbers if you want to. 
+MODEL_DIR=${MODEL_DIR:-"/root/model/Meta-Llama-3-70B"} +NGPUS=${NGPUS:-8} +WORLD_SIZE=${WORLD_SIZE:-1} +NUM_PROCESSES=$[${NGPUS}*$[WORLD_SIZE]] +SEQ_LEN=${SEQ_LEN:-32768} +BATCH_SIZE=${BATCH_SIZE:-1} +export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' +export WANDB_DISABLED=true +echo ${RANK}/$[WORLD_SIZE] +if [ ${MASTER_ADDR} == 'localhost' ]; then + export MASTER_ADDR=`hostname -i` +fi +echo ${MASTER_ADDR}:${MASTER_PORT} + +accelerate launch \ +--config_file examples/accelerate/ds_multi_nodes.yaml \ +--use_deepspeed \ +--num_machines ${WORLD_SIZE} \ +--num_processes ${NUM_PROCESSES} \ +--main_process_ip ${MASTER_ADDR} \ +--main_process_port ${MASTER_PORT} \ +--machine_rank ${RANK} \ +--rdzv_conf "rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT,rdzv_backend=c10d" \ +src/train.py \ +--model_name_or_path ${MODEL_DIR} \ +--stage sft \ +--do_train \ +--finetuning_type full \ +--parallel_mode dist_flash_attn \ +--deepspeed examples/deepspeed/ds_z3_offload_config.json \ +--dataset long_sft_128k \ +--template llama3 \ +--cutoff_len ${SEQ_LEN} \ +--max_samples 1000 \ +--overwrite_cache \ +--preprocessing_num_workers 16 \ +--output_dir ./output/70B_32K_bs_1M_rope_1M_step_1000_lr_2e-5 \ +--logging_steps 10 \ +--save_steps 500 \ +--plot_loss \ +--overwrite_output_dir \ +--per_device_train_batch_size ${BATCH_SIZE} \ +--gradient_accumulation_steps 4 \ +--learning_rate 2e-5 \ +--num_train_epochs 3.0 \ +--lr_scheduler_type cosine \ +--warmup_ratio 0.1 \ +--bf16 \ +--ddp_timeout 180000000 \ +--val_size 0.1 \ +--eval_strategy steps \ +--eval_steps 2 + +# In the saved files, there are model-00001-of-00003.safetensors to model-00001-of-00003.safetensors. Somehow model.safetensors is unnecessary and should be removed. +# rm output/7B_32K_bs_1M_rope_1M_step_1000_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 2 \ +# --output-dir ./output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5 \ +# --seed 2022 \ +# --wandb EasyContext \ +# --max-train-steps 1000 \ +# --learning-rate 2e-5 \ +# --dataset yaofu/slimpajama-per-source-length-upsample \ +# --model output/7B_32K_bs_1M_rope_1M_step_1000_lr_2e-5 \ +# --seq-length 65536 \ +# --rope-theta 5000000 \ +# --parallel_mode data_parallel + +# rm output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 4 \ +# --output-dir ./output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5 \ +# --seed 2023 \ +# --wandb EasyContext \ +# --max-train-steps 500 \ +# --learning-rate 2e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ +# --model output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5 \ +# --seq-length 256000 \ +# --rope-theta 10000000 \ +# --parallel_mode zigzag_ring_attn + +# rm output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 4 \ +# --output-dir ./output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5 \ +# --seed 2024 \ +# --wandb EasyContext \ +# --max-train-steps 500 \ +# --learning-rate 2e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ +# --model output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5 \ +# --seq-length 256000 \ +# --rope-theta 25000000 \ +# --parallel_mode zigzag_ring_attn + +# rm 
output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 4 \ +# --output-dir ./output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5 \ +# --seed 2025 \ +# --wandb EasyContext \ +# --max-train-steps 150 \ +# --learning-rate 2e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ +# --model output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5 \ +# --seq-length 256000 \ +# --rope-theta 50000000 \ +# --parallel_mode zigzag_ring_attn + +# rm output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 2 \ +# --output-dir ./output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5 \ +# --seed 2026 \ +# --wandb EasyContext \ +# --max-train-steps 300 \ +# --learning-rate 2e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_1M \ +# --model output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5 \ +# --seq-length 512000 \ +# --rope-theta 100000000 \ +# --parallel_mode zigzag_ring_attn + + +# rm output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 2 \ +# --output-dir ./output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5 \ +# --seed 2027 \ +# --wandb EasyContext \ +# --max-train-steps 90 \ +# --learning-rate 1e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_1M \ +# --model output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5 \ +# --seq-length 512000 \ +# --rope-theta 250000000 \ +# --parallel_mode zigzag_ring_attn + +# rm output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5/model.safetensors + + +### Finally we directly set the rope_theta in output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5/config.json to 1,000,000,000 diff --git a/Llama3-8B.sh b/Llama3-8B.sh new file mode 100644 index 0000000000..b746698d47 --- /dev/null +++ b/Llama3-8B.sh @@ -0,0 +1,172 @@ +# You can observe that the number of steps for different stage is quite different. They are not magic number. They are set to those numbers simply because I esitimate the time it takes to finish the training, and +# choose the number such that it fits my daily schedule>_<. This is for you to exactly reproduce my results. You many change the steps to other numbers if you want to. 
+MODEL_DIR=${MODEL_DIR:-"/mnt/zj-gpfs/home/qianhao/models/Meta-Llama-3-8B"} +NGPUS=${NGPUS:-8} +WORLD_SIZE=${WORLD_SIZE:-1} +NUM_PROCESSES=$[${NGPUS}*$[WORLD_SIZE]] +SEQ_LEN=${SEQ_LEN:-1024} +BATCH_SIZE=${BATCH_SIZE:-1} +export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' +export WANDB_DISABLED=true +echo ${RANK}/$[WORLD_SIZE] +if [ ${MASTER_ADDR} == 'localhost' ]; then + export MASTER_ADDR=`hostname -i` +fi +echo ${MASTER_ADDR}:${MASTER_PORT} + +accelerate launch \ +--config_file examples/accelerate/ds_multi_nodes.yaml \ +--use_deepspeed \ +--num_machines ${WORLD_SIZE} \ +--num_processes ${NUM_PROCESSES} \ +--main_process_ip ${MASTER_ADDR} \ +--main_process_port ${MASTER_PORT} \ +--machine_rank ${RANK} \ +--rdzv_conf "rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT,rdzv_backend=c10d" \ +src/train.py \ +--model_name_or_path ${MODEL_DIR} \ +--stage sft \ +--do_train \ +--finetuning_type full \ +--lora_target all \ +--parallel_mode dist_flash_attn \ +--deepspeed examples/deepspeed/ds_z3_offload_config.json \ +--dataset alpaca_en \ +--template llama3 \ +--cutoff_len ${SEQ_LEN} \ +--max_samples 1200 \ +--overwrite_cache \ +--preprocessing_num_workers 16 \ +--output_dir ./output/8B_1K_bs_1_step_1000_lr_2e-5 \ +--logging_steps 1 \ +--save_steps 500 \ +--plot_loss \ +--overwrite_output_dir \ +--per_device_train_batch_size ${BATCH_SIZE} \ +--per_device_eval_batch_size ${BATCH_SIZE} \ +--gradient_accumulation_steps 4 \ +--learning_rate 2e-5 \ +--num_train_epochs 3.0 \ +--lr_scheduler_type cosine \ +--warmup_ratio 0.1 \ +--bf16 \ +--ddp_timeout 180000000 \ +--val_size 0.1 \ +--eval_strategy steps \ +--dataloader_drop_last \ +--eval_steps 1001 + +# In the saved files, there are model-00001-of-00003.safetensors to model-00001-of-00003.safetensors. Somehow model.safetensors is unnecessary and should be removed. 
+# rm output/7B_32K_bs_1M_rope_1M_step_1000_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 2 \ +# --output-dir ./output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5 \ +# --seed 2022 \ +# --wandb EasyContext \ +# --max-train-steps 1000 \ +# --learning-rate 2e-5 \ +# --dataset yaofu/slimpajama-per-source-length-upsample \ +# --model output/7B_32K_bs_1M_rope_1M_step_1000_lr_2e-5 \ +# --seq-length 65536 \ +# --rope-theta 5000000 \ +# --parallel_mode data_parallel + +# rm output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 4 \ +# --output-dir ./output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5 \ +# --seed 2023 \ +# --wandb EasyContext \ +# --max-train-steps 500 \ +# --learning-rate 2e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ +# --model output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5 \ +# --seq-length 256000 \ +# --rope-theta 10000000 \ +# --parallel_mode zigzag_ring_attn + +# rm output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 4 \ +# --output-dir ./output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5 \ +# --seed 2024 \ +# --wandb EasyContext \ +# --max-train-steps 500 \ +# --learning-rate 2e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ +# --model output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5 \ +# --seq-length 256000 \ +# --rope-theta 25000000 \ +# --parallel_mode zigzag_ring_attn + +# rm output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 4 \ +# --output-dir ./output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5 \ +# --seed 2025 \ +# --wandb EasyContext \ +# --max-train-steps 150 \ +# --learning-rate 2e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ +# --model output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5 \ +# --seq-length 256000 \ +# --rope-theta 50000000 \ +# --parallel_mode zigzag_ring_attn + +# rm output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 2 \ +# --output-dir ./output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5 \ +# --seed 2026 \ +# --wandb EasyContext \ +# --max-train-steps 300 \ +# --learning-rate 2e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_1M \ +# --model output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5 \ +# --seq-length 512000 \ +# --rope-theta 100000000 \ +# --parallel_mode zigzag_ring_attn + + +# rm output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 2 \ +# --output-dir ./output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5 \ +# --seed 2027 \ +# --wandb EasyContext \ +# --max-train-steps 90 \ +# --learning-rate 1e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_1M \ +# --model output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5 \ +# --seq-length 
512000 \ +# --rope-theta 250000000 \ +# --parallel_mode zigzag_ring_attn + +# rm output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5/model.safetensors + + +### Finally we directly set the rope_theta in output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5/config.json to 1,000,000,000 diff --git a/data/dataset_info.json b/data/dataset_info.json index 1d226b3adc..70261447e7 100644 --- a/data/dataset_info.json +++ b/data/dataset_info.json @@ -8,6 +8,12 @@ "alpaca_zh_demo": { "file_name": "alpaca_zh_demo.json" }, + "long_sft_32k": { + "file_name": "sample_long_sft_32k_48M.json" + }, + "long_sft_128k": { + "file_name": "sample_long_sft_128k.parquet" + }, "glaive_toolcall_en_demo": { "file_name": "glaive_toolcall_en_demo.json", "formatting": "sharegpt", @@ -551,4 +557,4 @@ }, "folder": "python" } -} \ No newline at end of file +} diff --git a/examples/accelerate/ds_multi_nodes.yaml b/examples/accelerate/ds_multi_nodes.yaml new file mode 100644 index 0000000000..0b465fae9a --- /dev/null +++ b/examples/accelerate/ds_multi_nodes.yaml @@ -0,0 +1,15 @@ +debug: false +deepspeed_config: + deepspeed_config_file: examples/deepspeed/ds_z3_offload_config.json + deepspeed_multinode_launcher: standard + zero3_init_flag: true +distributed_type: DEEPSPEED +num_processes: 16 +downcast_bf16: 'no' +main_training_function: main +rdzv_backend: c10d +same_network: false +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/deepspeed/ds_z3_offload_config.json b/examples/deepspeed/ds_z3_offload_config.json index 026aabbcda..b00b8bc72e 100644 --- a/examples/deepspeed/ds_z3_offload_config.json +++ b/examples/deepspeed/ds_z3_offload_config.json @@ -17,13 +17,8 @@ }, "zero_optimization": { "stage": 3, - "offload_optimizer": { - "device": "cpu", - "pin_memory": true - }, "offload_param": { - "device": "cpu", - "pin_memory": true + "device": "cpu" }, "overlap_comm": true, "contiguous_gradients": true, @@ -34,5 +29,6 @@ "stage3_max_live_parameters": 1e9, "stage3_max_reuse_distance": 1e9, "stage3_gather_16bit_weights_on_model_save": true - } -} \ No newline at end of file + }, + "steps_per_print":1 +} diff --git a/llama3_full_sft_ds3.yaml b/llama3_full_sft_ds3.yaml new file mode 100644 index 0000000000..c37060276e --- /dev/null +++ b/llama3_full_sft_ds3.yaml @@ -0,0 +1,40 @@ +### model +model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct + +### method +stage: sft +do_train: true +finetuning_type: full +parallel_mode: dist_flash_attn +deepspeed: examples/deepspeed/ds_z3_offload_config.json + +### dataset +dataset: identity,alpaca_en_demo +template: llama3 +cutoff_len: 1024 +max_samples: 1000 +overwrite_cache: true +preprocessing_num_workers: 16 + +### output +output_dir: saves/llama3-8b/full/sft +logging_steps: 10 +save_steps: 500 +plot_loss: true +overwrite_output_dir: true + +### train +per_device_train_batch_size: 1 +gradient_accumulation_steps: 2 +learning_rate: 1.0e-4 +num_train_epochs: 3.0 +lr_scheduler_type: cosine +warmup_ratio: 0.1 +fp16: true +ddp_timeout: 180000000 + +### eval +val_size: 0.1 +per_device_eval_batch_size: 1 +eval_strategy: steps +eval_steps: 500 diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 1dc8dd8d38..3abf6a5a1a 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -3,7 +3,8 @@ import torch from transformers import DataCollatorForSeq2Seq - +from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union +from llamafactory.easy_context import prepare_seq_parallel_sft_inputs 
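The hunk below adds a `SeqParallelDataCollator` that pads each batch as usual and then hands it to `prepare_seq_parallel_sft_inputs`, so every rank keeps only its shard of the sequence. A minimal usage sketch follows; the checkpoint name and device choice are placeholders, not part of this patch:

```python
import torch
import torch.distributed as dist
from transformers import AutoTokenizer

from llamafactory.data.collator import SeqParallelDataCollator

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")  # placeholder model
collator = SeqParallelDataCollator(
    tokenizer=tokenizer,
    label_pad_token_id=-100,              # ignore padded positions in the loss
    seq_algo="dist_flash_attn",           # or "zigzag_ring_attn", "ulysses_attn", "data_parallel"
    rank=dist.get_rank(),                 # which sequence shard this process owns
    world_size=dist.get_world_size(),     # how many shards the sequence is split into
    device=torch.device("cuda", torch.cuda.current_device()),
)
# Each __call__ pads the features, then returns only this rank's slice of
# input_ids / attention_mask / position_ids and the labels (shifted for
# next-token prediction by prepare_seq_parallel_sft_inputs).
```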
@dataclass class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq): @@ -79,3 +80,30 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor batch["kto_tags"] = torch.tensor(kto_tags) return batch + +@dataclass +class SeqParallelDataCollator(DataCollatorForSeq2Seq): + r""" + Data collator for sequence parallel. + """ + seq_algo: str = "data_parallel" + rank: int = 0 + world_size: int = 8 + device: Optional[Any] = None + + def __call__(self, features: Sequence[Dict[str, Any]], return_tensors=None) -> Dict[str, torch.Tensor]: + batch = super().__call__(features, return_tensors) + if self.seq_algo == "data_parallel": + return batch + input_ids = batch["input_ids"] + attention_mask = batch["attention_mask"] + labels = batch["labels"] + batch = prepare_seq_parallel_sft_inputs(self.seq_algo, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=None, + labels=labels, + rank=self.rank, + world_size=self.world_size, + device=self.device) + return batch diff --git a/src/llamafactory/easy_context/__init__.py b/src/llamafactory/easy_context/__init__.py new file mode 100644 index 0000000000..687c018de2 --- /dev/null +++ b/src/llamafactory/easy_context/__init__.py @@ -0,0 +1,88 @@ +from .dist_flash_attn.prepare_input import prepare_dist_flash_attn_inputs, prepare_dist_flash_attn_sft_inputs +from .dist_flash_attn.monkey_patch import apply_dist_flash_attn_monkey_patch_llama +from .zigzag_ring_attn.prepare_inputs import prepare_zigzag_ring_attn_inputs, prepare_zigzag_ring_attn_sft_inputs +from .zigzag_ring_attn.monkey_patch import apply_zigzag_ring_attn_monkey_patch_llama +from .zigzag_ring_attn.monkey_patch import apply_zigzag_ring_attn_monkey_patch_mistral +from .unsloth_offloaded_gradient_checkpoint.monkey_patch import apply_unsloth_offloaded_gradient_checkpoint_monkey_patch +from .ulysses_attn.prepare_inputs import prepare_ulysses_attn_inputs, prepare_ulysses_attn_sft_inputs +from .ulysses_attn.monkey_patch import apply_ulysses_attn_monkey_patch_llama +import torch +import torch.nn.functional as F + +def prepare_seq_parallel_inputs( + seq_algo, input_ids, position_ids, target_ids, rank, world_size, device +): + if seq_algo == "zigzag_ring_attn": + return prepare_zigzag_ring_attn_inputs( + input_ids, position_ids, target_ids, rank, world_size, device + ) + elif seq_algo == "dist_flash_attn": + return prepare_dist_flash_attn_inputs( + input_ids, position_ids, target_ids, rank, world_size, device + ) + elif seq_algo == "ulysses_attn": + return prepare_ulysses_attn_inputs( + input_ids, position_ids, target_ids, rank, world_size, device + ) + elif seq_algo == "data_parallel": + return { + "local_input_ids": input_ids.to(device), + "local_position_ids": position_ids.to(device), + "local_target_ids": target_ids.to(device), + } + else: + raise ValueError(f"Invalid seq_algo: {seq_algo}") + +def prepare_seq_parallel_sft_inputs( + seq_algo, input_ids, attention_mask, position_ids, labels, rank, world_size, device +): + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + shift_labels = F.pad(labels, [0, 1], 'constant', -100)[:, 1:] + if seq_algo == "zigzag_ring_attn": + return prepare_zigzag_ring_attn_sft_inputs( + input_ids, attention_mask, position_ids, shift_labels, rank, world_size, device + ) + elif seq_algo == "dist_flash_attn": + return prepare_dist_flash_attn_sft_inputs( + input_ids, 
attention_mask, position_ids, shift_labels, rank, world_size, device + ) + elif seq_algo == "ulysses_attn": + return prepare_ulysses_attn_sft_inputs( + input_ids, attention_mask, position_ids, shift_labels, rank, world_size, device + ) + elif seq_algo == "data_parallel": + return { + "input_ids": input_ids, + "position_ids": position_ids, + "attention_mask": attention_mask, + "target_ids": labels, + } + else: + raise ValueError(f"Invalid seq_algo: {seq_algo}") + +def apply_seq_parallel_monkey_patch( + seq_algo, model +): + assert seq_algo in ["zigzag_ring_attn", "dist_flash_attn", "ulysses_attn", "data_parallel"], f"Invalid seq_algo: {seq_algo}" + assert model in ["llama", "mistral"], f"Invalid model: {model}" + if seq_algo == "data_parallel": + return + elif seq_algo == "zigzag_ring_attn" and model == "llama": + apply_zigzag_ring_attn_monkey_patch_llama() + elif seq_algo == "zigzag_ring_attn" and model == "mistral": + apply_zigzag_ring_attn_monkey_patch_mistral() + elif seq_algo == "dist_flash_attn" and model == "llama": + apply_dist_flash_attn_monkey_patch_llama() + elif seq_algo == "ulysses_attn" and model == "llama": + apply_ulysses_attn_monkey_patch_llama() + else: + raise ValueError(f"Invalid seq_algo: {seq_algo} or model: {model}") + +def prepare_dataloader(seq_algo, dataloader, acclerator): + if seq_algo == "data_parallel": + return acclerator.prepare(dataloader) + else: + return dataloader diff --git a/src/llamafactory/easy_context/dist_flash_attn/README.md b/src/llamafactory/easy_context/dist_flash_attn/README.md new file mode 100644 index 0000000000..2025265c3e --- /dev/null +++ b/src/llamafactory/easy_context/dist_flash_attn/README.md @@ -0,0 +1,11 @@ +# LightSeq +Taken from https://github.com/RulinShao/LightSeq. All credits to the authors. + +``` +@article{li2023lightseq, + title={LIGHTSEQ: SEQUENCE LEVEL PARALLELISM FOR DISTRIBUTED TRAINING OF LONG CONTEXT TRANS}, + author={Li, Dacheng and Shao, Rulin and Xie𝑠, Anze and Xing𝑐𝑚, Eric P and Gonzalez𝑏, Joseph E and Stoica𝑏, Ion and Ma𝑢, Xuezhe and Zhang𝑠, Hao}, + journal={arXiv preprint arXiv:2310.03294}, + year={2023} +} +``` \ No newline at end of file diff --git a/src/llamafactory/easy_context/dist_flash_attn/async_communication.py b/src/llamafactory/easy_context/dist_flash_attn/async_communication.py new file mode 100644 index 0000000000..610080ea3b --- /dev/null +++ b/src/llamafactory/easy_context/dist_flash_attn/async_communication.py @@ -0,0 +1,527 @@ +import threading +import math +import os + +import torch +import torch.distributed as dist +from torch.distributed import batch_isend_irecv, P2POp, isend, irecv + +# Sequence parallel group that the current rank belongs to. +_SEQUENCE_PARALLEL_GROUP = None + +# These values enable us to change the sequence parallel sizes on the fly. 
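A minimal sketch of how this module's globals are expected to be set up and queried from training code, assuming a torchrun-style launch where RANK and LOCAL_WORLD_SIZE are present in the environment (the helpers used here are defined further down in this file):

```python
import torch.distributed as dist

from llamafactory.easy_context.dist_flash_attn.async_communication import (
    initialize_distributed,
    get_sequence_parallel_rank,
    get_sequence_parallel_size,
)

initialize_distributed()                 # init_process_group + one sequence-parallel group over all ranks
sp_rank = get_sequence_parallel_rank()   # this GPU's position inside the sequence-parallel group
sp_size = get_sequence_parallel_size()   # number of GPUs the sequence is sharded across
print(f"sequence shard {sp_rank}/{sp_size} (global rank {dist.get_rank()})")
```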
+_SEQUENCE_PARALLEL_SIZE = None +_SEQUENCE_PARALLEL_RANK = None + +# Global buffer for P2P +_PEER_Q = None +_PEER_K = None +_PEER_V = None +_PEER_M = None +_PEER_L = None +_PEER_O = None +_PEER_Q_BWD = None +_PEER_K_BWD = None +_PEER_V_BWD = None +_PEER_O_BWD = None + +_DELTA_DQ = None +_PEER_L = None +_DELTA_DK = None +_DELTA_DV = None +_DK_DELTA_FROM_PEER = None +_DV_DELTA_FROM_PEER = None +_PEER_DO = None + + +_fwd_send_volume = 0 +_fwd_recv_volume = 0 +_bwd_send_volume = 0 +_bwd_recv_volume = 0 + +def initialize_distributed(): + if dist.is_initialized(): + if dist.get_rank() == 0: + print( + "torch distributed is already initialized, " + "skipping initialization ...", + flush=True, + ) + else: + if int(os.environ["RANK"]) == 0: + print("Initializing Torch distributed.") + dist.init_process_group(backend="nccl") + local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) + global_world_size = dist.get_world_size() + torch.cuda.set_device(dist.get_rank() % local_world_size) + + _initialize_sequence_parallel() + # create_nccl_communicators() + +def _initialize_sequence_parallel(sequence_parallel_size=None): + # Get world size and rank. Ensure some consistencies. + assert sequence_parallel_size is None, "Multiple sequence parallel group not implemented." + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + + if sequence_parallel_size is None: + sequence_parallel_size = world_size + else: + assert world_size % sequence_parallel_size == 0 + num_sequence_parallel_groups: int = world_size // sequence_parallel_size + + rank = torch.distributed.get_rank() + + # Build the sequence parallel groups. + global _SEQUENCE_PARALLEL_GROUP + global _SEQUENCE_PARALLEL_RANK + global _SEQUENCE_PARALLEL_SIZE + + assert ( + _SEQUENCE_PARALLEL_GROUP is None + ), 'sequence parallel group is already initialized' + for i in range(num_sequence_parallel_groups): + ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size) + group = torch.distributed.new_group(ranks) + if rank in ranks: + _SEQUENCE_PARALLEL_GROUP = group + _SEQUENCE_PARALLEL_RANK = ranks.index(rank) + _SEQUENCE_PARALLEL_SIZE = len(ranks) + + if dist.get_rank() == 0: + print("************ Finish sequence pralell group Initialization. 
***********") + # _set_global_memory_buffer() + +def maybe_get_set_global_memory_buffer(q, k, v, m, l, o): + global _PEER_Q, _PEER_K, _PEER_V, _PEER_M, _PEER_L, _PEER_O + if _PEER_Q is None: + try: + if get_sequence_parallel_rank() == 0: + print("Initializing global memoery buffer.") + except: + print("Initializing global memoery buffer.") + _PEER_Q = [torch.empty_like(q) for _ in range(2)] + _PEER_K = [torch.empty_like(k) for _ in range(2)] + _PEER_V = [torch.empty_like(v) for _ in range(2)] + _PEER_M = [torch.empty_like(m) for _ in range(2)] + _PEER_L = [torch.empty_like(l) for _ in range(2)] + _PEER_O = [torch.empty_like(o) for _ in range(2)] + + return _PEER_Q, _PEER_K, _PEER_V, _PEER_M, _PEER_L, _PEER_O + +def maybe_get_set_global_memory_buffer_bwd(dq, dk, dv, q, L, k, v, o, do): + global _DELTA_DQ, _DELTA_DK, _DELTA_DV, _DK_DELTA_FROM_PEER, _DV_DELTA_FROM_PEER,_PEER_Q_BWD, _PEER_L, _PEER_K_BWD, _PEER_V_BWD, _PEER_O_BWD, _PEER_DO + if _DELTA_DQ is None: + try: + if get_sequence_parallel_rank() == 0: + print("Initializing global memoery buffer for backward.") + except: + print("Initializing global memoery buffer for backward.") + _DELTA_DQ = [torch.empty_like(dq) for _ in range(2)] + _DELTA_DK = [torch.empty_like(dk) for _ in range(2)] + _DELTA_DV = [torch.empty_like(dv) for _ in range(2)] + _PEER_L = [torch.empty_like(L) for _ in range(2)] + + _DK_DELTA_FROM_PEER = torch.empty_like(dk) + _DV_DELTA_FROM_PEER = torch.empty_like(dv) + + # may already be initailized in the forward call. + # current forward and backward needs a transpose in q's format + _PEER_Q_BWD = [torch.empty_like(q) for _ in range(2)] + _PEER_K_BWD = [torch.empty_like(k) for _ in range(2)] + _PEER_V_BWD = [torch.empty_like(v) for _ in range(2)] + _PEER_O_BWD = [torch.empty_like(o) for _ in range(2)] + + _PEER_DO = [torch.empty_like(do) for _ in range(2)] + + return _DELTA_DQ, _DELTA_DK, _DELTA_DV, _DK_DELTA_FROM_PEER, _DV_DELTA_FROM_PEER, _PEER_Q_BWD, _PEER_L, _PEER_K_BWD, _PEER_V_BWD, _PEER_O_BWD, _PEER_DO + +def reset_global_memory_buffer(): + global _PEER_Q, _PEER_K, _PEER_V, _PEER_M, _PEER_L, _PEER_O, _DELTA_DQ, _PEER_L, _DELTA_DK, _DELTA_DV, _DK_DELTA_FROM_PEER, _DV_DELTA_FROM_PEER, _PEER_DO + _PEER_Q = None + _PEER_K = None + _PEER_V = None + _PEER_M = None + _PEER_L = None + _PEER_O = None + + _DELTA_DQ = None + _PEER_L = None + _DELTA_DK = None + _DELTA_DV = None + _DK_DELTA_FROM_PEER = None + _DV_DELTA_FROM_PEER = None + _PEER_DO = None + +# Pytorch defers the creation of nccl communicators to the first P2P call, +# We manually create them so the first isend does not hang without an irecv. +# reference: https://github.com/pytorch/pytorch/blob/main/torch/csrc/cuda/nccl.cpp#L138 +# Only support even number of GPUs. 
+def create_nccl_communicators(): + seq_rank = get_sequence_parallel_rank() + seq_group = get_sequence_parallel_group() + + empty_tensor = torch.empty(1,).cuda() + empty_tensor_2 = torch.empty(1,).cuda() + if torch.distributed.get_rank() % 2 == 0: + # sender + op1 = P2POp(op=isend, tensor=torch.empty(1,).cuda(), peer=seq_rank+1, group=seq_group) + op2 = P2POp(op=irecv, tensor=torch.empty(1,).cuda(), peer=seq_rank+1, group=seq_group) + #req = torch.distributed.isend(tensor=empty_tensor, dst=seq_rank + 1, group=seq_group) + dist.batch_isend_irecv([op1, op2]) + else: + # receiver + op1 = P2POp(op=irecv, tensor=torch.empty(1,).cuda(), peer=seq_rank-1, group=seq_group) + op2 = P2POp(op=isend, tensor=torch.empty(1,).cuda(), peer=seq_rank-1, group=seq_group) + #req = torch.distributed.isend(tensor=empty_tensor, dst=seq_rank + 1, group=seq_group) + handles = dist.batch_isend_irecv([op1, op2]) + #req = torch.distributed.irecv(tensor=empty_tensor, src=seq_rank - 1, group=seq_group) + dist.all_reduce(empty_tensor, group=seq_group) + +def get_sequence_parallel_group(): + """Get the sequence parallel group the caller rank belongs to.""" + #global _SEQUENCE_PARALLEL_GROUP + assert ( + _SEQUENCE_PARALLEL_GROUP is not None + ), 'sequence parallel group is not initialized' + return _SEQUENCE_PARALLEL_GROUP + +def get_sequence_parallel_rank(): + """Return my rank for the sequence parallel group.""" + global _SEQUENCE_PARALLEL_RANK + if _SEQUENCE_PARALLEL_RANK is not None: + return _SEQUENCE_PARALLEL_RANK + return torch.distributed.get_rank(group=get_sequence_parallel_group()) + +def get_sequence_parallel_size(): + """Return my rank for the sequence parallel group.""" + global _SEQUENCE_PARALLEL_SIZE + if _SEQUENCE_PARALLEL_SIZE is not None: + return _SEQUENCE_PARALLEL_SIZE + return torch.distributed.get_world_size(group=get_sequence_parallel_group()) + +def destroy_sequence_parallel(): + """Set the groups to none.""" + global _SEQUENCE_PARALLEL_GROUP + _SEQUENCE_PARALLEL_GROUP = None + +# whether this is the last time the kernel being called +def is_last_time(time_step): + # e.g. 
on a 8-GPU setup: + # R=0: 0 + # R=1: 1 + # R=2: 2 + # R=3: 3 + # R=4: 4, 5, 6, 7 + seq_rank = get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + if seq_rank <= seq_world_size // 2: # no one helps these ranks + rank_finish_time = seq_rank + else: + rank_finish_time = seq_world_size // 2 + return rank_finish_time == time_step + +# Whether the current time step is computing for local q +def is_compute_for_local_query(time_step): + # R=3,4,5,6,7: Yes + # R=0: 0 + # R=1: 0, 1 + # R=2: 0, 1, 2 + seq_rank = get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + if seq_rank >= min(seq_world_size // 2, time_step): + return True + return False + +# Whether the current time step is idle +def is_idle(time_step): + # 0, 1, 2, 3: 4 + # 4, 5, 6, 7: No + seq_rank = get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + + if seq_rank < (seq_world_size // 2) and time_step == seq_world_size // 2: + return True + return False + +# Whether the current time step needs to synchronize with a remote computed result +def is_sync_from_remote(time_step): + # R=0, 1, 2, 3, 4: No + # R=5: 4 + # R=6: 3, 4 + # R=7: 2, 3, 4 + seq_rank = get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + if seq_rank > max(seq_world_size // 2, seq_world_size - time_step): + return True + return False + +def maybe_send_recv_fwd_qkvo(q: torch.Tensor, peer_q: torch.Tensor, + k: torch.Tensor, peer_k: torch.Tensor, + v: torch.Tensor, peer_v: torch.Tensor, + o_stats: list,# peer_o_stats: list, + time_step: int, comm_mode, debug=False) -> torch.Tensor: + + seq_group = get_sequence_parallel_group() + seq_rank = get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + + # Handles for operations that actually need to be wait before going to the next iteration. + # For instance, QKV sender never needs to wait -> it seems fusing these calls help scheduler; + all_handles = [] + # KV logic: different than older version, every rank to send/recv its own kv, + # to balance communication. In a balanced communication, every step each rank + # should send/recv 4 tensors in total (kv, or qo). For instance, rank 0 when + # time step > 0, should send its own kv and send/recv qo. In the older version, + # rank 0 does not send its kv, and rely on a later rank to pass it, where the + # later rank has to (1) receive kv, send rank 0's kv and send/recv qo. + # Q (load balancing) logic: semantically, this will be "%" world size, so + # the same send/recv rank as KV. Note: Only support even number of machines. + # O (load balancing) logic: rank 0 sends result to rank 7 at time 1. + # It get delayed for one time step, and thus has different maybe_send/recv_rank. + # Use (time_step + 1) to easily convert to synchornize version. 
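    # Worked example of the send/recv rank arithmetic below (seq_world_size = 8):
    #   rank 2, time_step 0: maybe_send_rank = 3  -> send own k/v to rank 3;
    #                        maybe_recv_rank = 1  -> receive rank 1's k/v.
    #   rank 6, time_step 1: maybe_send_rank = 8  -> wraps around, so send q to rank 0 instead;
    #                        maybe_recv_rank = 4  -> receive rank 4's k/v.
    #   rank 1, time_step 2: maybe_send_rank = 4  -> send own k/v to rank 4;
    #                        maybe_recv_rank = -2 -> wraps around, so receive rank 6's q and
    #                        later compute that attention block on rank 6's behalf.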
+ maybe_send_rank = seq_rank + (time_step + 1) + maybe_recv_rank = seq_rank - (time_step + 1) + + if debug: + global _fwd_send_volume, _fwd_recv_volume, _bwd_send_volume, _bwd_recv_volume + _debug_send = _fwd_send_volume + _debug_recv = _fwd_recv_volume + + if maybe_send_rank >= seq_world_size: + #send q, no one needs to do remote computation in the last time step + if time_step < (seq_world_size // 2 - 1): + #print(f"t={time_step}: R={seq_rank} sends q to {maybe_send_rank % seq_world_size} (not wait)") + #q_send_handles.append(P2POp(op=isend, tensor=q, peer=maybe_send_rank % seq_world_size, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=q, peer=maybe_send_rank % seq_world_size, group=seq_group)) + if debug: + _fwd_send_volume += torch.numel(q) * q.element_size() + else: + # send kv + #print(f"t={time_step}: R={seq_rank} sends kv to {maybe_send_rank} (not wait)") + #kv_send_handles.append(P2POp(op=isend, tensor=k, peer=maybe_send_rank, group=seq_group)) + #kv_send_handles.append(P2POp(op=isend, tensor=v, peer=maybe_send_rank, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=k, peer=maybe_send_rank, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=v, peer=maybe_send_rank, group=seq_group)) + if debug: + _fwd_send_volume += torch.numel(k) * k.element_size() + _fwd_send_volume += torch.numel(v) * v.element_size() + + if maybe_recv_rank < 0: + # recv q, no one needs to do remote computation in the last time step + if time_step < (seq_world_size // 2 - 1): + # print(f"t={time_step}: R={seq_rank} receives q from {maybe_recv_rank % seq_world_size} (wait)") + #q_recv_handles.append(P2POp(op=irecv, tensor=peer_q, peer=maybe_recv_rank % seq_world_size, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=peer_q, peer=maybe_recv_rank % seq_world_size, group=seq_group)) + if debug: + _fwd_recv_volume += torch.numel(peer_q) * peer_q.element_size() + else: + # recv kv + #print(f"t={time_step}: R={seq_rank} receivs kv from {maybe_recv_rank} (wait)") + #kv_recv_handles.append(P2POp(op=irecv, tensor=peer_k, peer=maybe_recv_rank, group=seq_group)) + #kv_recv_handles.append(P2POp(op=irecv, tensor=peer_v, peer=maybe_recv_rank, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=peer_k, peer=maybe_recv_rank, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=peer_v, peer=maybe_recv_rank, group=seq_group)) + if debug: + _fwd_recv_volume += torch.numel(peer_k) * peer_k.element_size() + _fwd_recv_volume += torch.numel(peer_v) * peer_v.element_size() + + maybe_send_rank_o = seq_rank - (time_step - 1) + maybe_recv_rank_o = seq_rank + (time_step - 1) + if maybe_send_rank_o < 0 and time_step > 1: + for t in o_stats: + # print(f"t={time_step}: R={seq_rank} sends o to {maybe_send_rank_o % seq_world_size} (wait)") + #o_send_handles.append(P2POp(op=isend, tensor=t, peer=maybe_send_rank_o % seq_world_size, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=t, peer=maybe_send_rank_o % seq_world_size, group=seq_group)) + if debug: + _fwd_send_volume += torch.numel(t) * t.element_size() + if maybe_recv_rank_o >= seq_world_size and time_step > 1 : + for t in o_stats: + # print(f"t={time_step}: R={seq_rank} receives o from {maybe_recv_rank_o % seq_world_size} (wait)") + #o_recv_handles.append(P2POp(op=irecv, tensor=t, peer=maybe_recv_rank_o % seq_world_size, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=t, peer=maybe_recv_rank_o % seq_world_size, group=seq_group)) + if debug: + _fwd_recv_volume += torch.numel(t) * t.element_size() 
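    # Return path for remote results: a helper rank that computed attention for another rank's q
    # sends the o_stats triple (partial output o, running max m, denominator l) back to the query
    # owner, using the (time_step - 1) offset so the result follows the q with a one-step delay,
    # as described in the comment at the top of this function.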
+ + #reqs = [] + + if debug: + if seq_rank in [0, 8]: + print(f"R={seq_rank} time_step={time_step} increases: send {(_fwd_send_volume - _debug_send) * 1e-9} GB recv {(_fwd_recv_volume - _debug_recv) * 1e-9} GB") + #return reqs + all_reqs = launch_async_handles(all_handles, comm_mode) + return [all_reqs] + +# delta: may be you are using it for your local compute or as a distributed buffer to send to others +# .. Sorry for the bad naming.. +def maybe_send_recv_bwd_qkvo(dq_delta: torch.Tensor, dk_delta: torch.Tensor, + dv_delta: torch.Tensor, dk_delta_from_peer: torch.Tensor, + dv_delta_from_peer: torch.Tensor, q: torch.Tensor, + peer_q: torch.Tensor, L: torch.Tensor, + peer_L: torch.Tensor, k: torch.Tensor, + peer_k: torch.Tensor, v: torch.Tensor, + peer_v: torch.Tensor, o: torch.Tensor, + peer_o: torch.Tensor, do: torch.Tensor, + peer_do: torch.Tensor, time_step: int, comm_mode, debug=False): + + seq_group = get_sequence_parallel_group() + seq_rank = get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + + all_handles = [] + maybe_send_rank = seq_rank + (time_step + 1) + maybe_recv_rank = seq_rank - (time_step + 1) + + if debug: + global _fwd_send_volume, _fwd_recv_volume, _bwd_send_volume, _bwd_recv_volume + + if maybe_send_rank >= seq_world_size: + #send q, no one needs to do remote computation in the last time step + if time_step < (seq_world_size // 2 - 1): + all_handles.append(P2POp(op=isend, tensor=q, peer=maybe_send_rank % seq_world_size, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=L, peer=maybe_send_rank % seq_world_size, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=o, peer=maybe_send_rank % seq_world_size, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=do, peer=maybe_send_rank % seq_world_size, group=seq_group)) + if debug: + _bwd_send_volume += torch.numel(q) * q.element_size() + _bwd_send_volume += torch.numel(L) * L.element_size() + _bwd_send_volume += torch.numel(o) * o.element_size() + _bwd_send_volume += torch.numel(do) * do.element_size() + else: + # send kv + all_handles.append(P2POp(op=isend, tensor=k, peer=maybe_send_rank, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=v, peer=maybe_send_rank, group=seq_group)) + if debug: + _bwd_send_volume += torch.numel(k) * k.element_size() + _bwd_send_volume += torch.numel(v) * v.element_size() + + if maybe_recv_rank < 0: + # recv q, no one needs to do remote computation in the last time step + if time_step < (seq_world_size // 2 - 1): + all_handles.append(P2POp(op=irecv, tensor=peer_q, peer=maybe_recv_rank % seq_world_size, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=peer_L, peer=maybe_recv_rank % seq_world_size, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=peer_o, peer=maybe_recv_rank % seq_world_size, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=peer_do, peer=maybe_recv_rank % seq_world_size, group=seq_group)) + if debug: + _bwd_recv_volume += torch.numel(peer_q) * peer_q.element_size() + _bwd_recv_volume += torch.numel(peer_L) * peer_L.element_size() + _bwd_recv_volume += torch.numel(peer_o) * peer_o.element_size() + _bwd_recv_volume += torch.numel(peer_do) * peer_do.element_size() + else: + # recv kv + all_handles.append(P2POp(op=irecv, tensor=peer_k, peer=maybe_recv_rank, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=peer_v, peer=maybe_recv_rank, group=seq_group)) + if debug: + _bwd_recv_volume += torch.numel(peer_k) * peer_k.element_size() + _bwd_recv_volume += 
torch.numel(peer_v) * peer_v.element_size() + + # Whether I should update dq, dk and dv after waiting these requests + is_update_dq = False + is_update_dkv = False + + maybe_send_rank_dqkv = seq_rank - (time_step - 1) + maybe_recv_rank_dqkv = seq_rank + (time_step - 1) + + if time_step > 1: + if maybe_send_rank_dqkv < 0: + #print(f"BWD t={time_step}: R={seq_rank} sends dq delta to {maybe_send_rank_dqkv % seq_world_size}") + all_handles.append(P2POp(op=isend, tensor=dq_delta, peer=maybe_send_rank_dqkv % seq_world_size, group=seq_group)) + if debug: + _bwd_send_volume += torch.numel(dq_delta) * dq_delta.element_size() + else: + #print(f"BWD t={time_step}: R={seq_rank} sends dkv delta to {maybe_send_rank_dqkv}") + all_handles.append(P2POp(op=isend, tensor=dk_delta, peer=maybe_send_rank_dqkv, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=dv_delta, peer=maybe_send_rank_dqkv, group=seq_group)) + if debug: + _bwd_send_volume += torch.numel(dk_delta) * dk_delta.element_size() + _bwd_send_volume += torch.numel(dv_delta) * dv_delta.element_size() + + if maybe_recv_rank_dqkv >= seq_world_size: + #print(f"BWD t={time_step}: R={seq_rank} receives dq delta to {maybe_recv_rank_dqkv % seq_world_size}") + all_handles.append(P2POp(op=irecv, tensor=dq_delta, peer=maybe_recv_rank_dqkv % seq_world_size, group=seq_group)) + is_update_dq = True + if debug: + _bwd_recv_volume += torch.numel(dq_delta) * dq_delta.element_size() + else: + #print(f"BWD t={time_step}: R={seq_rank} receives dk dv delta from {maybe_recv_rank_dqkv}") + all_handles.append(P2POp(op=irecv, tensor=dk_delta_from_peer, peer=maybe_recv_rank_dqkv, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=dv_delta_from_peer, peer=maybe_recv_rank_dqkv, group=seq_group)) + is_update_dkv = True + if debug: + _bwd_recv_volume += torch.numel(dk_delta_from_peer) * dk_delta_from_peer.element_size() + _bwd_recv_volume += torch.numel(dv_delta_from_peer) * dv_delta_from_peer.element_size() + + # return [], is_update_dq, is_update_dkv + all_reqs = launch_async_handles(all_handles, comm_mode) + return [all_reqs], is_update_dq, is_update_dkv + +def maybe_send_recv_bwd_last_dkv(dk_delta: torch.Tensor, dv_delta: torch.Tensor, time_step, comm_mode, debug=False): + is_update_last_dkv = False + + seq_group = get_sequence_parallel_group() + seq_rank = get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + + if seq_world_size == 1: return [], is_update_last_dkv + + all_handles = [] + + if debug: + global _fwd_send_volume, _fwd_recv_volume, _bwd_send_volume, _bwd_recv_volume + + if time_step == seq_world_size // 2: + maybe_send_rank = seq_rank - time_step + maybe_recv_rank = seq_rank + time_step + + assert (maybe_send_rank >= 0) ^ (maybe_recv_rank < seq_world_size), "R={seq_rank} should be either sending or receiving dkv in the last time step." 
+ + if maybe_send_rank >= 0: + # print(f"BWD t={time_step}: R={seq_rank} last send dkv to {maybe_send_rank}") + all_handles.append(P2POp(op=isend, tensor=dk_delta, peer=maybe_send_rank, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=dv_delta, peer=maybe_send_rank, group=seq_group)) + if debug: + _bwd_send_volume += torch.numel(dk_delta) * dk_delta.element_size() + _bwd_send_volume += torch.numel(dv_delta) * dv_delta.element_size() + if maybe_recv_rank < seq_world_size: + # print(f"BWD t={time_step}: R={seq_rank} last receive dkv from {maybe_recv_rank}") + all_handles.append(P2POp(op=irecv, tensor=dk_delta, peer=maybe_recv_rank, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=dv_delta, peer=maybe_recv_rank, group=seq_group)) + if debug: + _bwd_recv_volume += torch.numel(dk_delta) * dk_delta.element_size() + _bwd_recv_volume += torch.numel(dv_delta) * dv_delta.element_size() + is_update_last_dkv = True + + # return [], is_update_last_dkv + all_reqs = launch_async_handles(all_handles, comm_mode) + + return [all_reqs], is_update_last_dkv + +def print_and_reset_comm_stats(): + seq_rank = get_sequence_parallel_rank() + + global _fwd_send_volume, _fwd_recv_volume, _bwd_send_volume, _bwd_recv_volume + _fwd_send_volume *= 1e-9 + _fwd_recv_volume *= 1e-9 + _bwd_send_volume *= 1e-9 + _bwd_recv_volume *= 1e-9 + + print(f"R={seq_rank} fwd send: {_fwd_send_volume} fwd recv: {_fwd_recv_volume}; bwd send: {_bwd_send_volume}, bwd recv: {_bwd_recv_volume} GB.") + _fwd_send_volume = 0 + _fwd_recv_volume = 0 + _bwd_send_volume = 0 + _bwd_recv_volume = 0 + +def launch_async_handles(handles, comm_mode): + global _args + if comm_mode == "nocomm": + #print("skipping communication for ablation") + return [] + if len(handles) > 0: + return dist.batch_isend_irecv(handles) + return [] + +def wait_async_handles(reqs): + if len(reqs) > 0: + for req in reqs: + for r in req: + r.wait() \ No newline at end of file diff --git a/src/llamafactory/easy_context/dist_flash_attn/lightseq_async_attn.py b/src/llamafactory/easy_context/dist_flash_attn/lightseq_async_attn.py new file mode 100644 index 0000000000..d776495bc3 --- /dev/null +++ b/src/llamafactory/easy_context/dist_flash_attn/lightseq_async_attn.py @@ -0,0 +1,743 @@ +import os +import math + +from einops import rearrange +import argparse + +import pytest +import torch +import torch.distributed as dist +from torch.distributed import ReduceOp +#from torch.profiler import profile, record_function, ProfilerActivity +import functools +import triton +import triton.language as tl +import time +import numpy as np +from tqdm import tqdm + +try: + from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward +except: + pass + +from .async_communication import (is_last_time, is_compute_for_local_query, is_sync_from_remote, is_idle, print_and_reset_comm_stats, + launch_async_handles, wait_async_handles, maybe_send_recv_fwd_qkvo, maybe_send_recv_bwd_qkvo, maybe_send_recv_bwd_last_dkv, reset_global_memory_buffer, + maybe_get_set_global_memory_buffer, maybe_get_set_global_memory_buffer_bwd, initialize_distributed, get_sequence_parallel_size, get_sequence_parallel_rank) + +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) + +@triton.jit +def _rescale_kernel( + peer_m, + m, + peer_l, + l, + peer_o, + o, + L, + stride_oz, stride_oh, stride_om, stride_on, + Z, H, N_CTX, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + LAST_STEP: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = 
tl.program_id(1) + o_offset = off_hz * stride_oh + peer_o_block_ptr = tl.make_block_ptr( + base=peer_o + o_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + o_block_ptr = tl.make_block_ptr( + base=o + o_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + + peer_m_ptrs = peer_m + off_hz * N_CTX + offs_m + m_ptrs = m + off_hz * N_CTX + offs_m + peer_l_ptrs = peer_l + off_hz * N_CTX + offs_m + l_ptrs = l + off_hz * N_CTX + offs_m + + peer_m_i = tl.load(peer_m_ptrs) + peer_m_i = peer_m_i.to(tl.float32) + m_i = tl.load(m_ptrs) + m_i = m_i.to(tl.float32) + peer_l_i = tl.load(peer_l_ptrs) + peer_l_i = peer_l_i.to(tl.float32) + l_i = tl.load(l_ptrs) + l_i = l_i.to(tl.float32) + + peer_acc = tl.load(peer_o_block_ptr) + peer_acc = peer_acc.to(tl.float32) + acc = tl.load(o_block_ptr) + acc = acc.to(tl.float32) + lo = 0 + hi = N_CTX + m_i_sync = tl.maximum(m_i, peer_m_i) + alpha = tl.math.exp2(m_i - m_i_sync) + peer_alpha = tl.math.exp2(peer_m_i - m_i_sync) + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + peer_acc_scale = peer_l_i * 0 + peer_alpha # workaround some compiler bug + + acc *= acc_scale[:, None] + peer_acc *= peer_acc_scale[:, None] + acc += peer_acc + l_i = l_i * acc_scale + peer_l_i * peer_acc_scale + # write back O, l, m + tl.store(m_ptrs, m_i_sync) + tl.store(l_ptrs, l_i) + if LAST_STEP: + acc = acc / l_i[:, None] + L_ptrs = L + off_hz * N_CTX + offs_m + tl.store(L_ptrs, m_i_sync / 1.44269504 + tl.math.log(l_i)) + tl.store(o_block_ptr, acc.to(tl.bfloat16)) + +@triton.jit +def _fwd_kernel( + Q, K, V, sm_scale, + m, + l, + O, + L, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + stride_oz, stride_oh, stride_om, stride_on, + Z, H, N_CTX, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_CAUSAL: tl.constexpr, + LAST_STEP: tl.constexpr +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + qvk_offset = off_hz * stride_qh + Q_block_ptr = tl.make_block_ptr( + base=Q + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + K_block_ptr = tl.make_block_ptr( + base=K + qvk_offset, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0) + ) + O_block_ptr = tl.make_block_ptr( + base=O + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l -> load from provided pointer + m_ptrs = m + off_hz * N_CTX + offs_m + l_ptrs = l + off_hz * N_CTX + offs_m + m_i = tl.load(m_ptrs) + m_i = m_i.to(tl.float32) + l_i = tl.load(l_ptrs) + l_i = l_i.to(tl.float32) + 
acc = tl.load(O_block_ptr) + acc = acc.to(tl.float32) + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout + q = tl.load(Q_block_ptr) + q = (q * qk_scale).to(tl.bfloat16) + # loop over k, v and update accumulator + lo = 0 + hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX + for start_n in range(lo, hi, BLOCK_N): + # -- load k, v -- + k = tl.load(K_block_ptr) + v = tl.load(V_block_ptr) + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + if IS_CAUSAL: + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + qk += tl.dot(q, k) + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc *= acc_scale[:, None] + acc += tl.dot(p.to(tl.bfloat16), v) + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + # write back original l and m + tl.store(m_ptrs, m_i) + tl.store(l_ptrs, l_i) + # write back O, L + if LAST_STEP: + acc = acc / l_i[:, None] + L_ptrs = L + off_hz * N_CTX + offs_m + tl.store(L_ptrs, m_i / 1.44269504 + tl.math.log(l_i)) + tl.store(O_block_ptr, acc.to(tl.bfloat16)) + +# for gqa/mqa to expand kv heads +def maybe_repeat_kv_fwd(nqh, kv): + bs, nkvh, slen, hdim = kv.shape + n_rep = nqh // nkvh + if n_rep == 1: + return kv + kv_expand = kv[:, :, None, :, :].expand(bs, nkvh, n_rep, slen, hdim) + return kv_expand.reshape(bs, nkvh * n_rep, slen, hdim) + +def maybe_repeat_kv_bwd(nqh, kv): + bs, slen, nkvh, hdim = kv.shape + n_rep = nqh // nkvh + if n_rep == 1: + return kv + kv_expand = kv[:, :, :, None, :].expand(bs, slen, nkvh, n_rep, hdim) + return kv_expand.reshape(bs, slen, nkvh * n_rep, hdim) + +# kv grad has shape bs, slen, nqh, hdim +def maybe_reduce_dkv(nkvh, dkv): + bs, slen, nqh, hdim = dkv.shape + n_rep = nqh // nkvh + if n_rep == 1: + return dkv + dkv_reshape = dkv.view(bs, slen, nkvh, n_rep, hdim) + return torch.sum(dkv_reshape, dim=3) + + +def _lightseq_forward(q, k, v, causal, sm_scale, comm_mode): + # maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x + # q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + # Why do I have to change it from 128 64 to 32 32? 
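For reference, the combine step that `_rescale_kernel` above performs when a peer's partial results arrive can be written in plain PyTorch. This is only an illustrative sketch of the math (base-2 running max and un-normalized accumulator, as in the kernels above), not code used by the patch:

```python
import torch

def merge_softmax_partials(acc, m, l, peer_acc, peer_m, peer_l):
    """Merge two un-normalized attention partials computed over disjoint key/value blocks."""
    m_sync = torch.maximum(m, peer_m)            # new running row-max (log2 domain)
    alpha = torch.exp2(m - m_sync)               # rescale factor for the local partial
    peer_alpha = torch.exp2(peer_m - m_sync)     # rescale factor for the peer partial
    acc = acc * alpha[:, None] + peer_acc * peer_alpha[:, None]
    l = l * alpha + peer_l * peer_alpha          # merged softmax denominator
    return acc, m_sync, l
```

On the last step the output is normalized as `acc / l[:, None]`, and the log-sum-exp needed by the backward pass is stored as `m_sync / 1.44269504 + log(l)`, matching the LAST_STEP branch of the kernels above.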
+ BLOCK_M = 32 + BLOCK_N = 32 + + bsz, nh, seq_len, hdim = q.shape + + m = torch.full((bsz * nh, seq_len), fill_value=-float("inf"), device=q.device, dtype=torch.float32) + l = torch.zeros_like(m) + L = torch.zeros_like(m) + o = torch.zeros_like(q) + + grid = (triton.cdiv(seq_len, BLOCK_M), bsz * nh, 1) + num_warps = 4 if Lk <= 64 else 8 + + seq_rank = get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + + # Initialize all buffers + peer_q, peer_k, peer_v, peer_m, peer_l, peer_o = maybe_get_set_global_memory_buffer(q, k, v, m, l, o) + + fwd_launch_helper = lambda q, k, v, m, l, o, L, IS_CAUSAL, LAST_STEP: _fwd_kernel[grid]( + q, k, v, sm_scale, + m, + l, + o, + L, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + q.shape[0], q.shape[1], q.shape[2], + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, + IS_CAUSAL=IS_CAUSAL, + LAST_STEP=LAST_STEP, + num_warps=num_warps, + num_stages=4) + + for time_step in range(seq_world_size // 2 + 1): + # This is important for cuda scheduler to execute nccl calls first. + torch.cuda.synchronize() + # Communication uses buffer_idx_1, and compute uses buffer_idx_2, which effectively are contents from the last time step. + buffer_idx_1 = time_step % 2 + buffer_idx_2 = (time_step - 1) % 2 + + reqs = maybe_send_recv_fwd_qkvo(q, peer_q[buffer_idx_1], k, peer_k[buffer_idx_1], v, peer_v[buffer_idx_1], + [peer_o[buffer_idx_1], peer_m[buffer_idx_1], peer_l[buffer_idx_1]], time_step, comm_mode) + if comm_mode == "sync": + # if seq_rank == 0: + # print("Immediate wait for abalation") + wait_async_handles(reqs) + if is_compute_for_local_query(time_step): + # print(f"t={time_step}: (Comp) R={seq_rank} local compute") + if time_step == 0: + fwd_launch_helper(q, maybe_repeat_kv_fwd(q.shape[1], k), maybe_repeat_kv_fwd(q.shape[1], v), m, l, o, L, True, is_last_time(time_step)) + else: + # if needs to sync from others, do not normalize here + fwd_launch_helper(q, maybe_repeat_kv_fwd(q.shape[1], peer_k[buffer_idx_2]), maybe_repeat_kv_fwd(q.shape[1], peer_v[buffer_idx_2]), m, l, o, L, False, not is_sync_from_remote(time_step) and is_last_time(time_step)) + elif is_idle(time_step): + # print(f"t={time_step}: (Comp) R={seq_rank} idle") + pass + else: + # print(f"t={time_step}: (Comp) R={seq_rank} helps other") + peer_m[buffer_idx_2] = torch.full_like(m, fill_value=-float("inf")) + peer_l[buffer_idx_2] = torch.zeros_like(l) + peer_o[buffer_idx_2] = torch.zeros_like(o) + + #print(f"rank 3 q is: {peer_q[buffer_idx_2]}") + fwd_launch_helper(peer_q[buffer_idx_2], maybe_repeat_kv_fwd(q.shape[1], k), maybe_repeat_kv_fwd(q.shape[1], v), peer_m[buffer_idx_2], peer_l[buffer_idx_2], peer_o[buffer_idx_2], None, False, False) + + if comm_mode == "lightseq": + # Make sure tensors for next steps are ready + wait_async_handles(reqs) + # sync between statistics get from other ranks and the local ones + if is_sync_from_remote(time_step): + _rescale_kernel[grid]( + peer_m[buffer_idx_1], + m, + peer_l[buffer_idx_1], + l, + peer_o[buffer_idx_1], + o, + L, + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + o.shape[0], o.shape[1], o.shape[2], + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, + LAST_STEP=is_last_time(time_step), + num_warps=num_warps, + num_stages=4) + return q, k, v, o, L + +def _lightseq_backward(do, q, k, v, o, L, sm_scale, comm_mode, backward_engine): + BLOCK = 128 + q, k, v, o, do 
= [rearrange(_x, 'b h s d -> b s h d').contiguous() for _x in [q, k, v, o, do]] + L = rearrange(L, '(b h) s -> b h s', b=q.shape[0]) + + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + + # maybe gqa + nqh = q.shape[2] + nkvh = k.shape[2] + is_gqa = (nqh > nkvh) + + seq_rank = get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + + # Initialize all backward buffers + dq_delta, dk_delta, dv_delta, dk_delta_from_peer, dv_delta_from_peer, \ + peer_q, peer_L, peer_k, peer_v, peer_o, peer_do = maybe_get_set_global_memory_buffer_bwd(dq, dk, dv, q, L, k, v, o, do) + + for time_step in range(0, get_sequence_parallel_size() // 2 + 1): + torch.cuda.synchronize() + buffer_idx_1 = time_step % 2 + buffer_idx_2 = (time_step - 1) % 2 + + reqs, is_update_dq, is_update_dkv = maybe_send_recv_bwd_qkvo(dq_delta[buffer_idx_1], dk_delta[buffer_idx_1], dv_delta[buffer_idx_1], dk_delta_from_peer, dv_delta_from_peer, q, peer_q[buffer_idx_1], L, peer_L[buffer_idx_1], k, peer_k[buffer_idx_1], v, peer_v[buffer_idx_1], o, peer_o[buffer_idx_1], do, peer_do[buffer_idx_1], time_step, comm_mode) + if comm_mode == "sync": + # if seq_rank == 0: + # print("(bwd) Immediate wait for abalation") + wait_async_handles(reqs) + + if is_compute_for_local_query(time_step): + if time_step == 0: + if backward_engine == "flash": + _flash_attn_backward(do, q, k, v, o, L, dq, dk, dv, 0.0, sm_scale, True, (-1,-1), None, False) + else: + inp = Inputs(query=q, key=maybe_repeat_kv_bwd(q.shape[2], k), value=maybe_repeat_kv_bwd(q.shape[2], v), attn_bias=xformers.ops.LowerTriangularMask(), p=0, scale=sm_scale) + op_ctx = Context(lse=L, out=o, rng_state=None) + # Let xformers dispatch the correct backend + grads = _memory_efficient_attention_backward(ctx=op_ctx, inp=inp, grad=do, op=None) + dq = grads.dq + dk, dv = maybe_reduce_dkv(nkvh, grads.dk), maybe_reduce_dkv(nkvh, grads.dv) + else: + if backward_engine == "flash": + _flash_attn_backward(do, q, peer_k[buffer_idx_2], peer_v[buffer_idx_2], o, L, dq_delta[buffer_idx_2], dk_delta[buffer_idx_2], dv_delta[buffer_idx_2], 0.0, sm_scale, False, (-1,-1), None, False) + else: + inp = Inputs(query=q, key=maybe_repeat_kv_bwd(q.shape[2], peer_k[buffer_idx_2]), value=maybe_repeat_kv_bwd(q.shape[2], peer_v[buffer_idx_2]), attn_bias=None, p=0, scale=sm_scale) + op_ctx = Context(lse=L, out=o, rng_state=None) + grads = _memory_efficient_attention_backward(ctx=op_ctx, inp=inp, grad=do, op=None) + dq_delta[buffer_idx_2] = grads.dq + dk_delta[buffer_idx_2], dv_delta[buffer_idx_2] = maybe_reduce_dkv(nkvh, grads.dk), maybe_reduce_dkv(nkvh, grads.dv) + dq += dq_delta[buffer_idx_2] + elif is_idle(time_step): + pass + else: + if backward_engine == "flash": + _flash_attn_backward(peer_do[buffer_idx_2], peer_q[buffer_idx_2], k, v, peer_o[buffer_idx_2], peer_L[buffer_idx_2], dq_delta[buffer_idx_2], dk_delta[buffer_idx_2], dv_delta[buffer_idx_2], 0.0, sm_scale, False, (-1,-1), None, False) + else: + inp = Inputs(query=peer_q[buffer_idx_2], key=maybe_repeat_kv_bwd(q.shape[2], k), value=maybe_repeat_kv_bwd(q.shape[2], v), attn_bias=None, p=0, scale=sm_scale) + op_ctx = Context(lse=peer_L[buffer_idx_2], out=peer_o[buffer_idx_2], rng_state=None) + grads = _memory_efficient_attention_backward(ctx=op_ctx, inp=inp, grad=peer_do[buffer_idx_2], op=None) + dq_delta[buffer_idx_2] = grads.dq + dk_delta[buffer_idx_2], dv_delta[buffer_idx_2] = maybe_reduce_dkv(nkvh, grads.dk), maybe_reduce_dkv(nkvh, grads.dv) + dk += dk_delta[buffer_idx_2] + dv += dv_delta[buffer_idx_2] + + if 
comm_mode == "lightseq": + # Make sure tensors for next steps are ready + wait_async_handles(reqs) + + # The last time step needs to send dk and dv immediately, move it up here to maximize overlap with the following three addition. + reqs, is_update_last_dkv = maybe_send_recv_bwd_last_dkv(dk_delta[buffer_idx_2], dv_delta[buffer_idx_2], time_step, comm_mode) + + if comm_mode == "sync": + # if seq_rank == 0: + # print("(bwd) dkv Immediate wait for abalation") + wait_async_handles(reqs) + # apply dq_delta, dk_delta and dv_delta from remote + if is_update_dq: + dq += dq_delta[buffer_idx_1] + if is_update_dkv: + dk += dk_delta_from_peer + dv += dv_delta_from_peer + + if comm_mode == "lightseq": + wait_async_handles(reqs) + # apply dk_delta and dv_delta to sender + if is_update_last_dkv: + dk += dk_delta[buffer_idx_2] + dv += dv_delta[buffer_idx_2] + + dq, dk, dv = [rearrange(_x, 'b h s d -> b s h d') for _x in [dq, dk, dv]] + return dq, dk, dv + +class _attention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, causal, sm_scale): + try: + global args + comm_mode = args.comm_mode + backward_engine = args.backward_engine + except: + comm_mode = 'lightseq' + backward_engine = 'flash' + + q, k, v, o, L = _lightseq_forward(q, k, v, causal, sm_scale, comm_mode) + + ctx.save_for_backward(q, k, v, o, L) + ctx.sm_scale = sm_scale + ctx.comm_mode = comm_mode + ctx.backward_engine = backward_engine + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, L = ctx.saved_tensors + sm_scale = ctx.sm_scale + + dq, dk, dv = _lightseq_backward(do, q, k, v, o, L, sm_scale, ctx.comm_mode, ctx.backward_engine) + return dq, dk, dv, None, None + +attention = _attention.apply + + +#@pytest.mark.parametrize('causal', [False, True]) +#@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(6, 9, 1024, 64)]) +def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.bfloat16): + torch.manual_seed(20) + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + + rank = dist.get_rank() + world_size = dist.get_world_size() + seq_per_rank = N_CTX // world_size + + sm_scale = 0.5 + dout = torch.randn_like(q) + # reference implementation + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + assert causal + if causal: + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).half() + ref_out = torch.matmul(p, v) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + + # triton implementation + + a, b, c, d = q.size() + real_q = q[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True) + real_k = k[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True) + real_v = v[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True) + real_do = dout[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True) + + tri_out = attention(real_q, real_k, real_v, causal, sm_scale).half() + + # compare + assert 
torch.allclose(ref_out[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_out, atol=1e-2, rtol=0), f" rank {rank} fails forward" + print(f" *** rank {rank} passes forward") + tri_out.backward(real_do) + tri_dv, real_v.grad = real_v.grad.clone(), None + tri_dk, real_k.grad = real_k.grad.clone(), None + tri_dq, real_q.grad = real_q.grad.clone(), None + assert torch.allclose(ref_dq[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dq, atol=1e-2, rtol=0), f" rank {rank} fails backward dq" + assert torch.allclose(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dk, atol=1e-2, rtol=0), f"rank {rank} fails backward dk" #f" {ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dk} {torch.max(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dk)} rank {rank} fails backward dk" + assert torch.allclose(ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dv, atol=1e-2, rtol=0), f"rank {rank} fails backward dv {ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dv} {torch.max(ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dv)} rank {rank} fails backward dv" + print(f"rank {rank} passes backward") + + +def test_gqa(Z, H, KVH, N_CTX, D_HEAD, causal, dtype=torch.bfloat16): + torch.manual_seed(177) + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + k = torch.empty((Z, KVH, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + v = torch.empty((Z, KVH, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + + rank = dist.get_rank() + world_size = dist.get_world_size() + seq_per_rank = N_CTX // world_size + + sm_scale = 0.5 + dout = torch.randn_like(q) + # torch reference implementation + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + ref_k = maybe_repeat_kv_fwd(q.shape[1], k).clone().detach().requires_grad_(True) + ref_v = maybe_repeat_kv_fwd(q.shape[1], v).clone().detach().requires_grad_(True) + p = torch.matmul(q, ref_k.transpose(2,3)) * sm_scale + assert causal + if causal: + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).half() + ref_out = torch.matmul(p, ref_v) + ref_out.backward(dout) + ref_dv, v.grad = ref_v.grad.clone(), None + ref_dv = (maybe_reduce_dkv(KVH, ref_dv.transpose(1,2))).transpose(1,2) + ref_dk, k.grad = ref_k.grad.clone(), None + ref_dk = (maybe_reduce_dkv(KVH, ref_dk.transpose(1,2))).transpose(1,2) + ref_dq, q.grad = q.grad.clone(), None + + # flash reference + from flash_attn import flash_attn_qkvpacked_func, flash_attn_func + flash_q = q.transpose(1,2).clone().detach().requires_grad_(True) + flash_k = k.transpose(1,2).clone().detach().requires_grad_(True) + flash_v = v.transpose(1,2).clone().detach().requires_grad_(True) + flash_ref_out = flash_attn_func(flash_q, flash_k, flash_v, 0, sm_scale, True) + flash_ref_out.backward(dout.transpose(1,2)) + flash_ref_out = flash_ref_out.transpose(1,2) + flash_ref_dv, v.grad = flash_v.grad.clone(), None + flash_ref_dv = flash_ref_dv.transpose(1,2) + flash_ref_dk, k.grad = flash_k.grad.clone(), None + flash_ref_dk = flash_ref_dk.transpose(1,2) + flash_ref_dq, q.grad = flash_q.grad.clone(), None + flash_ref_dq = flash_ref_dq.transpose(1,2) + + # triton implementation + + a, b, c, d = q.size() + real_q = q[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True) + real_k = k[:,:, 
rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, KVH, -1, d).contiguous().clone().detach().requires_grad_(True) + real_v = v[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, KVH, -1, d).contiguous().clone().detach().requires_grad_(True) + real_do = dout[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True) + + tri_out = attention(real_q, real_k, real_v, causal, sm_scale).half() + + # compare + assert torch.allclose(flash_ref_out[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_out, atol=1e-2, rtol=0), f" rank {rank} fails forward against flash" + print(f" *** rank {rank} passes forward") + tri_out.backward(real_do) + tri_dv, real_v.grad = real_v.grad.clone(), None + tri_dk, real_k.grad = real_k.grad.clone(), None + tri_dq, real_q.grad = real_q.grad.clone(), None + assert torch.allclose(flash_ref_dq[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dq, atol=1e-2, rtol=0), f" rank {rank} fails backward dq against flash" + #print(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].shape, ref_dk.shape, tri_dk.shape) + assert torch.allclose(flash_ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dk, atol=1e-2, rtol=0), f"rank {rank} fails backward dk against flash {flash_ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dk} {torch.max(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dk)} rank {rank} fails backward dk" + assert torch.allclose(flash_ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dv, atol=1e-2, rtol=0), f"rank {rank} fails backward dv against flash {flash_ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dv} {torch.max(flash_ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dv)} rank {rank} fails backward dv" + print(f"rank {rank} passes backward against flash") + + assert torch.allclose(ref_out[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_out, atol=1e-2, rtol=0), f" rank {rank} fails forward" + print(f" *** rank {rank} passes forward") + assert torch.allclose(ref_dq[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dq, atol=1e-2, rtol=0), f" rank {rank} fails backward dq" + #print(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].shape, ref_dk.shape, tri_dk.shape) + assert torch.allclose(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dk, atol=1e-2, rtol=0), f"rank {rank} fails backward dk {ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dk} {torch.max(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dk)} rank {rank} fails backward dk" + assert torch.allclose(ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dv, atol=1e-2, rtol=0), f"rank {rank} fails backward dv {ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dv} {torch.max(ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dv)} rank {rank} fails backward dv" + print(f"rank {rank} passes backward") + +#BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 +try: + from flash_attn.flash_attn_interface import \ + flash_attn_qkvpacked_func as flash_attn_func + FLASH_VER = 2 +except BaseException: + try: + from flash_attn.flash_attn_interface import flash_attn_func + FLASH_VER = 1 + except BaseException: + FLASH_VER = None +HAS_FLASH = FLASH_VER is not None +HAS_FLASH = None +ONLY_FLASH = False + +#BATCH, N_HEADS, N_CTX, 
D_HEAD = 4, 48, 4096, 64 +BATCH, N_HEADS, N_CTX, D_HEAD = 1, 32, None, 128 +# vary seq length for fixed head and batch=4 +configs = [triton.testing.Benchmark( + x_names=['N_CTX'], + x_vals=[2**i for i in range(18, 19)],#[ 20, 21]],#[10, 11, 12, 13, 14, 15, 16, 17, 18]], + line_arg='provider', + line_vals=['triton'] if not ONLY_FLASH else [] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] if not ONLY_FLASH else [] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []), + styles=[('red', '-'), ('blue', '-')], + ylabel='ms', + plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-{causal}', + args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.bfloat16, 'mode': mode, 'causal': causal} +) for mode in ["all"] for causal in [True]] + +# @triton.testing.perf_report(configs) +def bench_flash_attention(BATCH, H, KVH, N_CTX, D_HEAD, causal, mode, provider, args, dtype=torch.bfloat16, device="cuda"): + assert mode == "all" #mode in ['fwd', 'bwd'] + n_warmup = 10 + n_repeat = 10 + cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda') + seq_rank = get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + if provider == "triton": + q = torch.randn((BATCH, H, N_CTX // seq_world_size, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((BATCH, KVH, N_CTX // seq_world_size, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((BATCH, KVH, N_CTX // seq_world_size, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + if seq_rank == 0: + print(f"Benchmarking per GPU qkv shape: {q.shape}") + sm_scale = 1.3 + fwd_fn = lambda: attention(q, k, v, causal, sm_scale) + if provider == "flash": + qkv = torch.randn((BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True) + if FLASH_VER == 1: + lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) + cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32) + cu_seqlens[1:] = lengths.cumsum(0) + qkv = qkv.reshape(BATCH * N_CTX, 3, H, D_HEAD) + fwd_fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=causal) + elif FLASH_VER == 2: + fwd_fn = lambda: flash_attn_func(qkv, causal=causal) + else: + raise ValueError(f'unknown {FLASH_VER = }') + + flops_per_matmul = 2. 
* BATCH * H * N_CTX * N_CTX * D_HEAD / seq_world_size + attn_flops = 2 * flops_per_matmul + + assert causal + if causal: + attn_flops *= 0.5 + fwd_flops = attn_flops + bwd_flops = attn_flops * 2.5 # 2.0(bwd) + 0.5(recompute) + + o = fwd_fn() + do = torch.randn_like(o) + bwd_fn = lambda: o.backward(do, retain_graph=True) + + def run_benchmark(fn): + time_list = [] + for _ in tqdm(range(n_warmup)): + cache.zero_() + fn() + torch.cuda.synchronize() + if args.debug: + print_and_reset_comm_stats() + for i in tqdm(range(n_repeat)): + cache.zero_() + torch.cuda.synchronize() + time_s = time.time() + fn() + torch.cuda.synchronize() + time_e = time.time() + time_list.append((time_e - time_s) * 1000.0) + if args.debug: + print_and_reset_comm_stats() + return np.asarray(time_list) + + fwd_time_arr = run_benchmark(fwd_fn) + bwd_time_arr = run_benchmark(bwd_fn) + + fwd_flops_ps = fwd_flops / np.mean(fwd_time_arr) * 1e-9 + print(f"(FWD) R={seq_rank} avg: {np.mean(fwd_time_arr)}, std: {np.std(fwd_time_arr)} flops: {fwd_flops_ps} \n") + + bwd_flops_ps = bwd_flops / np.mean(bwd_time_arr) * 1e-9 + print(f"(BWD) R={seq_rank} avg: {np.mean(bwd_time_arr)}, std: {np.std(bwd_time_arr)} flops: {bwd_flops_ps} \n") + + # total + total_time_arr = fwd_time_arr + bwd_time_arr + total_flops = fwd_flops + bwd_flops + total_flops_ps = total_flops / np.mean(total_time_arr) * 1e-9 + print(f"(Total) R={seq_rank} avg: {np.mean(total_time_arr)}, std: {np.std(total_time_arr)} flops: {total_flops_ps} \n") + + #return total_flops_ps + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--comm-mode", type=str, default="lightseq") + parser.add_argument("--debug", action="store_true") + parser.add_argument("--run-mode", type=str, default="benchmark") + parser.add_argument("--bs", type=int, default=1) + parser.add_argument("--n_heads", type=int, default=32) + parser.add_argument("--n_kvheads", type=int, default=32) + parser.add_argument("--d_head", type=int, default=128) + parser.add_argument("--start_ctx", type=int, default=12) + parser.add_argument("--end_ctx", type=int, default=18) + parser.add_argument("--forward_engine", type=str, default="triton") + parser.add_argument("--backward_engine", type=str, default="flash") + + global args + args = parser.parse_args() + initialize_distributed() + + assert args.forward_engine == "triton", "Only triton forward is implmented." + assert args.backward_engine in ["flash", "xformers"], "Only flash or xformers backward is implemented." + + if args.backward_engine == "flash": + from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward + else: + try: + import xformers.ops + from xformers.ops.fmha.common import Inputs, Context + from xformers.ops.fmha import _memory_efficient_attention_backward + from xformers.ops.fmha import cutlass, flash + except ImportError: + print("xformers not found! 
Please install it before trying to use it.") + + if args.run_mode == "benchmark": + for N_CTX in [2**i for i in range(args.start_ctx, args.end_ctx)]: + bench_flash_attention(args.bs, args.n_heads, args.n_kvheads, N_CTX, args.d_head, True, "all", "triton", args)#.run(save_path='.', print_data=True) + reset_global_memory_buffer() + else: + assert args.run_mode == "test" + for N_CTX in [2048, 4096]: + test_op(1, 16, N_CTX, 128, True) + #test_gqa(1, 16, 8, N_CTX, 128, True) + reset_global_memory_buffer() diff --git a/src/llamafactory/easy_context/dist_flash_attn/lightseq_async_attn_varlen.py b/src/llamafactory/easy_context/dist_flash_attn/lightseq_async_attn_varlen.py new file mode 100644 index 0000000000..388ecd4c81 --- /dev/null +++ b/src/llamafactory/easy_context/dist_flash_attn/lightseq_async_attn_varlen.py @@ -0,0 +1,772 @@ +import os +import math + +from einops import rearrange +import argparse + +import pytest +import torch +import torch.distributed as dist +from torch.distributed import ReduceOp +#from torch.profiler import profile, record_function, ProfilerActivity + +import triton +import triton.language as tl +import time +import numpy as np +from tqdm import tqdm + +try: + from flash_attn.flash_attn_interface import _flash_attn_varlen_backward +except: + pass + +from .async_communication import (is_last_time, is_compute_for_local_query, is_sync_from_remote, is_idle, print_and_reset_comm_stats, + launch_async_handles, wait_async_handles, maybe_send_recv_fwd_qkvo, maybe_send_recv_bwd_qkvo, maybe_send_recv_bwd_last_dkv, reset_global_memory_buffer, + maybe_get_set_global_memory_buffer, maybe_get_set_global_memory_buffer_bwd, initialize_distributed, get_sequence_parallel_size, get_sequence_parallel_rank) + +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) + +@triton.jit +def _rescale_kernel( + peer_m, + m, + peer_l, + l, + peer_o, + o, + L, + stride_oz, stride_oh, stride_om, stride_on, + Z, H, N_CTX, + seqlen_q_rounded, seqlen_peer_q_rounded, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + LAST_STEP: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + o_offset = off_hz * stride_oh + peer_o_block_ptr = tl.make_block_ptr( + base=peer_o + o_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + o_block_ptr = tl.make_block_ptr( + base=o + o_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + + peer_m_ptrs = peer_m + off_hz * seqlen_peer_q_rounded + offs_m + m_ptrs = m + off_hz * seqlen_q_rounded + offs_m + peer_l_ptrs = peer_l + off_hz * seqlen_peer_q_rounded + offs_m + l_ptrs = l + off_hz * seqlen_q_rounded + offs_m + + peer_m_i = tl.load(peer_m_ptrs) + peer_m_i = peer_m_i.to(tl.float32) + m_i = tl.load(m_ptrs) + m_i = m_i.to(tl.float32) + peer_l_i = tl.load(peer_l_ptrs) + peer_l_i = peer_l_i.to(tl.float32) + l_i = tl.load(l_ptrs) + l_i = l_i.to(tl.float32) + + peer_acc = tl.load(peer_o_block_ptr)#, boundary_check=(0, 1), padding_option='zero') + peer_acc = peer_acc.to(tl.float32) + acc = tl.load(o_block_ptr) #, boundary_check=(0, 1), padding_option='zero') + acc = acc.to(tl.float32) + lo = 0 + hi = N_CTX + m_i_sync = tl.maximum(m_i, peer_m_i) + alpha = tl.math.exp2(m_i - m_i_sync) + peer_alpha 
= tl.math.exp2(peer_m_i - m_i_sync) + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + peer_acc_scale = peer_l_i * 0 + peer_alpha # workaround some compiler bug + + acc *= acc_scale[:, None] + peer_acc *= peer_acc_scale[:, None] + acc += peer_acc + l_i = l_i * acc_scale + peer_l_i * peer_acc_scale + # write back O, l, m + tl.store(m_ptrs, m_i_sync) + tl.store(l_ptrs, l_i) + if LAST_STEP: + acc = acc / l_i[:, None] + L_ptrs = L + off_hz * N_CTX + offs_m + tl.store(L_ptrs, m_i_sync / 1.44269504 + tl.math.log(l_i)) + tl.store(o_block_ptr, acc.to(tl.bfloat16), boundary_check=(0, 1)) + +@triton.jit +def _fwd_kernel( + Q, K, V, sm_scale, + m, + l, + O, + L, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + stride_oz, stride_oh, stride_om, stride_on, + Z, H, N_CTX, + seqlen_q_rounded, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_CAUSAL: tl.constexpr, + LAST_STEP: tl.constexpr +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + qvk_offset = off_hz * stride_qh + Q_block_ptr = tl.make_block_ptr( + base=Q + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + K_block_ptr = tl.make_block_ptr( + base=K + qvk_offset, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0) + ) + O_block_ptr = tl.make_block_ptr( + base=O + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l -> load from provided pointer + # (TODO): Why float32? 
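+    # Note (answering the TODO above, standard flash-attention reasoning rather than anything
+    # specific to this patch): m (running row-max) and l (running softmax denominator) are kept
+    # in float32 because the exp2 rescaling accumulated across blocks and across sequence-parallel
+    # steps loses precision quickly in bfloat16, while q/k/v and the stored output stay in bf16.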
+ m_ptrs = m + off_hz * seqlen_q_rounded + offs_m + l_ptrs = l + off_hz * seqlen_q_rounded + offs_m + m_i = tl.load(m_ptrs) + m_i = m_i.to(tl.float32) + l_i = tl.load(l_ptrs) + l_i = l_i.to(tl.float32) + acc = tl.load(O_block_ptr) + acc = acc.to(tl.float32) + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout + q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option='zero') + q = (q * qk_scale).to(tl.bfloat16) + # loop over k, v and update accumulator + lo = 0 + hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX + for start_n in range(lo, hi, BLOCK_N): + # -- load k, v -- + k = tl.load(K_block_ptr, boundary_check=(1,), padding_option='zero') + v = tl.load(V_block_ptr, boundary_check=(0,), padding_option='zero') + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + if IS_CAUSAL: + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + qk += tl.dot(q, k) + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc *= acc_scale[:, None] + acc += tl.dot(p.to(tl.bfloat16), v) + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + # write back original l and m + tl.store(m_ptrs, m_i) + tl.store(l_ptrs, l_i) + # write back O, L + if LAST_STEP: + acc = acc / l_i[:, None] + L_ptrs = L + off_hz * seqlen_q_rounded + offs_m + tl.store(L_ptrs, m_i / 1.44269504 + tl.math.log(l_i)) + tl.store(O_block_ptr, acc.to(tl.bfloat16), boundary_check=(0, 1)) + +# for gqa/mqa to expand kv heads +def maybe_repeat_kv_fwd(nqh, kv): + bs, nkvh, slen, hdim = kv.shape + n_rep = nqh // nkvh + if n_rep == 1: + return kv + kv_expand = kv[:, :, None, :, :].expand(bs, nkvh, n_rep, slen, hdim) + return kv_expand.reshape(bs, nkvh * n_rep, slen, hdim) + +def maybe_repeat_kv_bwd(nqh, kv): + bs, slen, nkvh, hdim = kv.shape + n_rep = nqh // nkvh + if n_rep == 1: + return kv + kv_expand = kv[:, :, :, None, :].expand(bs, slen, nkvh, n_rep, hdim) + return kv_expand.reshape(bs, slen, nkvh * n_rep, hdim) + +# kv grad has shape bs, slen, nqh, hdim +def maybe_reduce_dkv(nkvh, dkv): + bs, slen, nqh, hdim = dkv.shape + n_rep = nqh // nkvh + if n_rep == 1: + return dkv + #print("*"*100, dkv.shape, bs, slen, nkvh, n_rep, hdim) + dkv_reshape = dkv.view(bs, slen, nkvh, n_rep, hdim) + #print("-"*100, dkv_reshape.shape, bs, slen, nkvh, n_rep, hdim) + return torch.sum(dkv_reshape, dim=3) + + +def _lightseq_forward_varlen(q, k, v, causal, sm_scale, comm_mode): + # maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x + # q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + # assert Lq == Lk and Lk == Lv + # assert Lk in {16, 32, 64, 128} + BLOCK_M = 128 + BLOCK_N = 64 + + bsz, nh, unpadded_seq_len, hdim = q.shape + cu_seq_lens = torch.arange(0, (bsz+1) * unpadded_seq_len, unpadded_seq_len, dtype=torch.int32, device=q.device) + max_seqlen = unpadded_seq_len + seqlen_q_rounded = math.ceil(q.shape[2] / BLOCK_M) * BLOCK_M + + m = torch.full((bsz * nh, seqlen_q_rounded), 
fill_value=-float("inf"), device=q.device, dtype=torch.float32) + l = torch.zeros((bsz * nh, seqlen_q_rounded), device=q.device, dtype=torch.float32) + L = torch.zeros((bsz * nh, seqlen_q_rounded), device=q.device, dtype=torch.float32) + o = torch.zeros_like(q) + + grid = (triton.cdiv(q.shape[2], BLOCK_M), bsz * nh, 1) + num_warps = 4 if Lk <= 64 else 8 + + seq_rank = get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + + # Initialize all buffers + peer_q, peer_k, peer_v, peer_m, peer_l, peer_o = maybe_get_set_global_memory_buffer(q, k, v, m, l, o) + + fwd_launch_helper = lambda q, k, v, m, l, o, L, IS_CAUSAL, LAST_STEP: _fwd_kernel[grid]( + q, k, v, sm_scale, + m, + l, + o, + L, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + q.shape[0], q.shape[1], q.shape[2], + seqlen_q_rounded, + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, + IS_CAUSAL=IS_CAUSAL, + LAST_STEP=LAST_STEP, + num_warps=num_warps, + num_stages=4) + + for time_step in range(seq_world_size // 2 + 1): + # This is important for cuda scheduler to execute nccl calls first. + torch.cuda.synchronize() + # Communication uses buffer_idx_1, and compute uses buffer_idx_2, which effectively are contents from the last time step. + buffer_idx_1 = time_step % 2 + buffer_idx_2 = (time_step - 1) % 2 + + reqs = maybe_send_recv_fwd_qkvo(q, peer_q[buffer_idx_1], k, peer_k[buffer_idx_1], v, peer_v[buffer_idx_1], + [peer_o[buffer_idx_1], peer_m[buffer_idx_1], peer_l[buffer_idx_1]], time_step, comm_mode) + if comm_mode == "sync": + # if seq_rank == 0: + # print("Immediate wait for abalation") + wait_async_handles(reqs) + if is_compute_for_local_query(time_step): + # print(f"t={time_step}: (Comp) R={seq_rank} local compute") + if time_step == 0: + fwd_launch_helper(q, maybe_repeat_kv_fwd(q.shape[1], k), maybe_repeat_kv_fwd(q.shape[1], v), m, l, o, L, True, is_last_time(time_step)) + else: + # if needs to sync from others, do not normalize here + fwd_launch_helper(q, maybe_repeat_kv_fwd(q.shape[1], peer_k[buffer_idx_2]), maybe_repeat_kv_fwd(q.shape[1], peer_v[buffer_idx_2]), m, l, o, L, False, not is_sync_from_remote(time_step) and is_last_time(time_step)) + elif is_idle(time_step): + # print(f"t={time_step}: (Comp) R={seq_rank} idle") + pass + else: + # print(f"t={time_step}: (Comp) R={seq_rank} helps other") + peer_m[buffer_idx_2] = torch.full_like(m, fill_value=-float("inf")) + peer_l[buffer_idx_2] = torch.zeros_like(l) + peer_o[buffer_idx_2] = torch.zeros_like(o) + + #print(f"rank 3 q is: {peer_q[buffer_idx_2]}") + fwd_launch_helper(peer_q[buffer_idx_2], maybe_repeat_kv_fwd(q.shape[1], k), maybe_repeat_kv_fwd(q.shape[1], v), peer_m[buffer_idx_2], peer_l[buffer_idx_2], peer_o[buffer_idx_2], None, False, False) + + if comm_mode == "lightseq": + # Make sure tensors for next steps are ready + wait_async_handles(reqs) + # sync between statistics get from other ranks and the local ones + if is_sync_from_remote(time_step): +# print(f"t={time_step}: (Comp) R={seq_rank} sync with other - last time: {is_last_time(time_step)}") + seqlen_peer_q_rounded = peer_l[buffer_idx_1].shape[-1] + _rescale_kernel[grid]( + peer_m[buffer_idx_1], + m, + peer_l[buffer_idx_1], + l, + peer_o[buffer_idx_1], + o, + L, + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + o.shape[0], o.shape[1], o.shape[2], + seqlen_q_rounded, seqlen_peer_q_rounded, + BLOCK_M=BLOCK_M, 
BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, + LAST_STEP=is_last_time(time_step), + num_warps=num_warps, + num_stages=4) + return q, k, v, o, L, cu_seq_lens, max_seqlen + +def _lightseq_backward_varlen(do, q, k, v, o, L, sm_scale, comm_mode, backward_engine, cu_seq_lens, max_seqlen): + BLOCK = 128 + L = rearrange(L[:, :max_seqlen].contiguous(), '(b h) s -> b h s', b=q.shape[0]) + q, k, v, o, do = [rearrange(_x, 'b h s d -> (b s) h d').contiguous() for _x in [q, k, v, o, do]] + + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + + # maybe gqa + nqh = q.shape[1] + nkvh = k.shape[1] + is_gqa = (nqh > nkvh) + + seq_rank = get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + + # Initialize all backward buffers + dq_delta, dk_delta, dv_delta, dk_delta_from_peer, dv_delta_from_peer, \ + peer_q, peer_L, peer_k, peer_v, peer_o, peer_do = maybe_get_set_global_memory_buffer_bwd(dq, dk, dv, q, L, k, v, o, do) + + for time_step in range(0, get_sequence_parallel_size() // 2 + 1): + torch.cuda.synchronize() + buffer_idx_1 = time_step % 2 + buffer_idx_2 = (time_step - 1) % 2 + + reqs, is_update_dq, is_update_dkv = maybe_send_recv_bwd_qkvo(dq_delta[buffer_idx_1], dk_delta[buffer_idx_1], dv_delta[buffer_idx_1], dk_delta_from_peer, dv_delta_from_peer, q, peer_q[buffer_idx_1], L, peer_L[buffer_idx_1], k, peer_k[buffer_idx_1], v, peer_v[buffer_idx_1], o, peer_o[buffer_idx_1], do, peer_do[buffer_idx_1], time_step, comm_mode) + if comm_mode == "sync": + wait_async_handles(reqs) + + if is_compute_for_local_query(time_step): + if time_step == 0: + assert backward_engine == "flash", "We haven't supportted varlen feature in xformer" + if backward_engine == "flash": + _flash_attn_varlen_backward(do, q, k, v, o, L, dq, dk, dv, cu_seq_lens, cu_seq_lens, max_seqlen, max_seqlen, 0.0, sm_scale, True, None) + else: + inp = Inputs(query=q, key=maybe_repeat_kv_bwd(q.shape[2], k), value=maybe_repeat_kv_bwd(q.shape[2], v), attn_bias=xformers.ops.LowerTriangularMask(), p=0, scale=sm_scale) + op_ctx = Context(lse=L, out=o, rng_state=None) + # Let xformers dispatch the correct backend + grads = _memory_efficient_attention_backward(ctx=op_ctx, inp=inp, grad=do, op=None) + dq = grads.dq + dk, dv = maybe_reduce_dkv(nkvh, grads.dk), maybe_reduce_dkv(nkvh, grads.dv) + else: + assert backward_engine == "flash", "We haven't supportted varlen feature in xformer" + if backward_engine == "flash": + _flash_attn_varlen_backward(do, q, peer_k[buffer_idx_2], peer_v[buffer_idx_2], o, L, dq_delta[buffer_idx_2], dk_delta[buffer_idx_2], dv_delta[buffer_idx_2], cu_seq_lens, cu_seq_lens, max_seqlen, max_seqlen, 0.0, sm_scale, False, None) + else: + inp = Inputs(query=q, key=maybe_repeat_kv_bwd(q.shape[2], peer_k[buffer_idx_2]), value=maybe_repeat_kv_bwd(q.shape[2], peer_v[buffer_idx_2]), attn_bias=None, p=0, scale=sm_scale) + op_ctx = Context(lse=L, out=o, rng_state=None) + grads = _memory_efficient_attention_backward(ctx=op_ctx, inp=inp, grad=do, op=None) + dq_delta[buffer_idx_2] = grads.dq + dk_delta[buffer_idx_2], dv_delta[buffer_idx_2] = maybe_reduce_dkv(nkvh, grads.dk), maybe_reduce_dkv(nkvh, grads.dv) + dq += dq_delta[buffer_idx_2] + elif is_idle(time_step): + # print(f"BWD t={time_step}: (Comp) R={seq_rank} idle") + pass + else: + # print(f"BWD t={time_step}: (Comp) R={seq_rank} helps other") + assert backward_engine == "flash", "We haven't supportted varlen feature in xformer" + if backward_engine == "flash": + _flash_attn_varlen_backward(peer_do[buffer_idx_2], peer_q[buffer_idx_2], k, v, 
peer_o[buffer_idx_2], peer_L[buffer_idx_2], dq_delta[buffer_idx_2], dk_delta[buffer_idx_2], dv_delta[buffer_idx_2], cu_seq_lens, cu_seq_lens, max_seqlen, max_seqlen, 0.0, sm_scale, False, None) + else: + inp = Inputs(query=peer_q[buffer_idx_2], key=maybe_repeat_kv_bwd(q.shape[2], k), value=maybe_repeat_kv_bwd(q.shape[2], v), attn_bias=None, p=0, scale=sm_scale) + op_ctx = Context(lse=peer_L[buffer_idx_2], out=peer_o[buffer_idx_2], rng_state=None) + grads = _memory_efficient_attention_backward(ctx=op_ctx, inp=inp, grad=peer_do[buffer_idx_2], op=None) + dq_delta[buffer_idx_2] = grads.dq + dk_delta[buffer_idx_2], dv_delta[buffer_idx_2] = maybe_reduce_dkv(nkvh, grads.dk), maybe_reduce_dkv(nkvh, grads.dv) + dk += dk_delta[buffer_idx_2] + dv += dv_delta[buffer_idx_2] + + if comm_mode == "lightseq": + # Make sure tensors for next steps are ready + wait_async_handles(reqs) + + # The last time step needs to send dk and dv immediately, move it up here to maximize overlap with the following three addition. + reqs, is_update_last_dkv = maybe_send_recv_bwd_last_dkv(dk_delta[buffer_idx_2], dv_delta[buffer_idx_2], time_step, comm_mode) + + if comm_mode == "sync": + # if seq_rank == 0: + # print("(bwd) dkv Immediate wait for abalation") + wait_async_handles(reqs) + # apply dq_delta, dk_delta and dv_delta from remote + if is_update_dq: + dq += dq_delta[buffer_idx_1] + if is_update_dkv: + dk += dk_delta_from_peer + dv += dv_delta_from_peer + + if comm_mode == "lightseq": + wait_async_handles(reqs) + # apply dk_delta and dv_delta to sender + if is_update_last_dkv: + dk += dk_delta[buffer_idx_2] + dv += dv_delta[buffer_idx_2] + + dq, dk, dv = [rearrange(_x, '(b s) h d -> b h s d', s=max_seqlen) for _x in [dq, dk, dv]] + return dq, dk, dv + +class _attention_varlen(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, causal, sm_scale): + try: + global args + comm_mode = args.comm_mode + backward_engine = args.backward_engine + except: + comm_mode = 'lightseq' + backward_engine = 'flash' + + q, k, v, o, L, cu_seq_lens, max_seqlen = _lightseq_forward_varlen(q, k, v, causal, sm_scale, comm_mode) + + ctx.save_for_backward(q, k, v, o, L, cu_seq_lens) + ctx.max_seqlen = max_seqlen + ctx.sm_scale = sm_scale + ctx.comm_mode = comm_mode + ctx.backward_engine = backward_engine + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, L, cu_seq_lens = ctx.saved_tensors + sm_scale = ctx.sm_scale + max_seqlen = ctx.max_seqlen + + dq, dk, dv = _lightseq_backward_varlen(do, q, k, v, o, L, sm_scale, ctx.comm_mode, ctx.backward_engine, cu_seq_lens, max_seqlen) + return dq, dk, dv, None, None + +dist_attn_varlen = _attention_varlen.apply + + +#@pytest.mark.parametrize('causal', [False, True]) +#@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(6, 9, 1024, 64)]) +def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.bfloat16): + torch.manual_seed(20) + rank = dist.get_rank() + world_size = dist.get_world_size() + + + PAD = world_size * 256 + seq_per_rank = (N_CTX-PAD) // world_size + q = torch.empty((Z, H, N_CTX-PAD, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + k = torch.empty((Z, H, N_CTX-PAD, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + v = torch.empty((Z, H, N_CTX-PAD, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + + # DEBUG: mask out + #mask = torch.zeros(Z, H, seq_per_rank * (world_size - 1), D_HEAD).cuda() + #mask_2 = torch.ones(Z, H, seq_per_rank, D_HEAD).cuda() + #mask = torch.cat((mask, 
mask_2), dim=-2).to(dtype) + #q = mask * q + #k = mask * k + #v = mask * v + + sm_scale = 0.5 + dout = torch.randn_like(q) + # reference implementation + M = torch.tril(torch.ones((N_CTX-PAD, N_CTX-PAD), device="cuda")) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + assert causal + if causal: + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).half() + ref_out = torch.matmul(p, v) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + + # triton implementation + + a, b, c, d = q.size() + real_q = q[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True) + real_k = k[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True) + real_v = v[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True) + real_do = dout[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True) + + tri_out = dist_attn_varlen(real_q, real_k, real_v, causal, sm_scale).half() + + # compare + assert torch.allclose(ref_out[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_out, atol=1e-2, rtol=0), f" rank {rank} fails forward" + print(f" *** rank {rank} passes forward") + tri_out.backward(real_do) + tri_dv, real_v.grad = real_v.grad.clone(), None + tri_dk, real_k.grad = real_k.grad.clone(), None + tri_dq, real_q.grad = real_q.grad.clone(), None + assert torch.allclose(ref_dq[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dq, atol=1e-2, rtol=0), f"rank {rank} fails backward dq" #{ref_dq[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dq} {torch.max(ref_dq[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dq)} rank {rank} fails backward dk" + assert torch.allclose(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dk, atol=1e-2, rtol=0), f"rank {rank} fails backward dk" #{ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dk} {torch.max(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dk)} rank {rank} fails backward dk" + assert torch.allclose(ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dv, atol=1e-2, rtol=0), f"rank {rank} fails backward dv" #{ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dv} {torch.max(ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dv)} rank {rank} fails backward dv" + print(f"rank {rank} passes backward") + +#TODO(High Priority): Investigate why rank 0 tends to have larger numerical difference. 
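+# Minimal usage sketch (illustrative only, mirroring test_op above; names and the sm_scale choice
+# are assumptions, not part of the training path): each rank feeds its own contiguous slice of the
+# sequence, shaped [bsz, n_heads, seq_len // world_size, head_dim], and receives the attention
+# output for that slice; the full-sequence softmax statistics are exchanged internally by the
+# lightseq communication schedule.
+#
+#   local_q = q[:, :, rank * seq_per_rank:(rank + 1) * seq_per_rank, :].contiguous()
+#   local_k = k[:, :, rank * seq_per_rank:(rank + 1) * seq_per_rank, :].contiguous()
+#   local_v = v[:, :, rank * seq_per_rank:(rank + 1) * seq_per_rank, :].contiguous()
+#   local_out = dist_attn_varlen(local_q, local_k, local_v, True, local_q.shape[-1] ** -0.5)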
+def test_gqa(Z, H, KVH, N_CTX, D_HEAD, causal, dtype=torch.bfloat16): + torch.manual_seed(177) + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + k = torch.empty((Z, KVH, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + v = torch.empty((Z, KVH, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + + rank = dist.get_rank() + world_size = dist.get_world_size() + seq_per_rank = N_CTX // world_size + + sm_scale = 0.5 + dout = torch.randn_like(q) + # torch reference implementation + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + ref_k = maybe_repeat_kv_fwd(q.shape[1], k).clone().detach().requires_grad_(True) + ref_v = maybe_repeat_kv_fwd(q.shape[1], v).clone().detach().requires_grad_(True) + #print(q.shape, ref_k.shape, k.shape) + p = torch.matmul(q, ref_k.transpose(2,3)) * sm_scale + assert causal + if causal: + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).half() + ref_out = torch.matmul(p, ref_v) + ref_out.backward(dout) + ref_dv, v.grad = ref_v.grad.clone(), None + #print("Before reduce", ref_dv.shape) + ref_dv = (maybe_reduce_dkv(KVH, ref_dv.transpose(1,2))).transpose(1,2) + #print("After reduce", ref_dv.shape) + ref_dk, k.grad = ref_k.grad.clone(), None + ref_dk = (maybe_reduce_dkv(KVH, ref_dk.transpose(1,2))).transpose(1,2) + ref_dq, q.grad = q.grad.clone(), None + + # flash reference + from flash_attn import flash_attn_qkvpacked_func, flash_attn_func + flash_q = q.transpose(1,2).clone().detach().requires_grad_(True) + flash_k = k.transpose(1,2).clone().detach().requires_grad_(True) + flash_v = v.transpose(1,2).clone().detach().requires_grad_(True) + flash_ref_out = flash_attn_func(flash_q, flash_k, flash_v, 0, sm_scale, True) + flash_ref_out.backward(dout.transpose(1,2)) + flash_ref_out = flash_ref_out.transpose(1,2) + flash_ref_dv, v.grad = flash_v.grad.clone(), None + flash_ref_dv = flash_ref_dv.transpose(1,2) + flash_ref_dk, k.grad = flash_k.grad.clone(), None + flash_ref_dk = flash_ref_dk.transpose(1,2) + flash_ref_dq, q.grad = flash_q.grad.clone(), None + flash_ref_dq = flash_ref_dq.transpose(1,2) + + # triton implementation + + a, b, c, d = q.size() + real_q = q[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True) + real_k = k[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, KVH, -1, d).contiguous().clone().detach().requires_grad_(True) + real_v = v[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, KVH, -1, d).contiguous().clone().detach().requires_grad_(True) + real_do = dout[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True) + + tri_out = dist_attn_varlen(real_q, real_k, real_v, causal, sm_scale).half() + + # compare + assert torch.allclose(flash_ref_out[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_out, atol=1e-2, rtol=0), f" rank {rank} fails forward against flash" + print(f" *** rank {rank} passes forward") + tri_out.backward(real_do) + tri_dv, real_v.grad = real_v.grad.clone(), None + tri_dk, real_k.grad = real_k.grad.clone(), None + tri_dq, real_q.grad = real_q.grad.clone(), None + assert torch.allclose(flash_ref_dq[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dq, atol=1e-2, rtol=0), f" rank {rank} fails backward dq against flash" + #print(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * 
seq_per_rank, :].shape, ref_dk.shape, tri_dk.shape) + assert torch.allclose(flash_ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dk, atol=1e-2, rtol=0), f"rank {rank} fails backward dk against flash {flash_ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dk} {torch.max(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dk)} rank {rank} fails backward dk" + assert torch.allclose(flash_ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dv, atol=1e-2, rtol=0), f"rank {rank} fails backward dv against flash {flash_ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dv} {torch.max(flash_ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dv)} rank {rank} fails backward dv" + print(f"rank {rank} passes backward against flash") + + assert torch.allclose(ref_out[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_out, atol=1e-2, rtol=0), f" rank {rank} fails forward" + print(f" *** rank {rank} passes forward") + assert torch.allclose(ref_dq[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dq, atol=1e-2, rtol=0), f" rank {rank} fails backward dq" + #print(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].shape, ref_dk.shape, tri_dk.shape) + assert torch.allclose(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dk, atol=1e-2, rtol=0), f"rank {rank} fails backward dk {ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dk} {torch.max(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dk)} rank {rank} fails backward dk" + assert torch.allclose(ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dv, atol=1e-2, rtol=0), f"rank {rank} fails backward dv {ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dv} {torch.max(ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dv)} rank {rank} fails backward dv" + print(f"rank {rank} passes backward") + +#BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 +try: + from flash_attn.flash_attn_interface import \ + flash_attn_qkvpacked_func as flash_attn_func + FLASH_VER = 2 +except BaseException: + try: + from flash_attn.flash_attn_interface import flash_attn_func + FLASH_VER = 1 + except BaseException: + FLASH_VER = None +HAS_FLASH = FLASH_VER is not None +HAS_FLASH = None +ONLY_FLASH = False + +#BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 +BATCH, N_HEADS, N_CTX, D_HEAD = 1, 32, None, 128 +# vary seq length for fixed head and batch=4 +configs = [triton.testing.Benchmark( + x_names=['N_CTX'], + x_vals=[2**i for i in range(18, 19)],#[ 20, 21]],#[10, 11, 12, 13, 14, 15, 16, 17, 18]], + line_arg='provider', + line_vals=['triton'] if not ONLY_FLASH else [] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] if not ONLY_FLASH else [] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []), + styles=[('red', '-'), ('blue', '-')], + ylabel='ms', + plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-{causal}', + args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.bfloat16, 'mode': mode, 'causal': causal} +) for mode in ["all"] for causal in [True]] + +# @triton.testing.perf_report(configs) +def bench_flash_attention(BATCH, H, KVH, N_CTX, D_HEAD, causal, mode, provider, args, dtype=torch.bfloat16, device="cuda"): + assert mode == "all" #mode in ['fwd', 'bwd'] + n_warmup = 10 + n_repeat = 10 + cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda') + seq_rank = 
get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + if provider == "triton": + q = torch.randn((BATCH, H, N_CTX // seq_world_size, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((BATCH, KVH, N_CTX // seq_world_size, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((BATCH, KVH, N_CTX // seq_world_size, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + if seq_rank == 0: + print(f"Benchmarking per GPU qkv shape: {q.shape}") + sm_scale = 1.3 + fwd_fn = lambda: dist_attn_varlen(q, k, v, causal, sm_scale) + if provider == "flash": + qkv = torch.randn((BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True) + if FLASH_VER == 1: + lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) + cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32) + cu_seqlens[1:] = lengths.cumsum(0) + qkv = qkv.reshape(BATCH * N_CTX, 3, H, D_HEAD) + fwd_fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=causal) + elif FLASH_VER == 2: + fwd_fn = lambda: flash_attn_func(qkv, causal=causal) + else: + raise ValueError(f'unknown {FLASH_VER = }') + + flops_per_matmul = 2. * BATCH * H * N_CTX * N_CTX * D_HEAD / seq_world_size + attn_flops = 2 * flops_per_matmul + + assert causal + if causal: + attn_flops *= 0.5 + fwd_flops = attn_flops + bwd_flops = attn_flops * 2.5 # 2.0(bwd) + 0.5(recompute) + + o = fwd_fn() + do = torch.randn_like(o) + bwd_fn = lambda: o.backward(do, retain_graph=True) + + def run_benchmark(fn): + time_list = [] + for _ in tqdm(range(n_warmup)): + cache.zero_() + fn() + torch.cuda.synchronize() + if args.debug: + print_and_reset_comm_stats() + for i in tqdm(range(n_repeat)): + cache.zero_() + torch.cuda.synchronize() + time_s = time.time() + fn() + torch.cuda.synchronize() + time_e = time.time() + time_list.append((time_e - time_s) * 1000.0) + if args.debug: + print_and_reset_comm_stats() + return np.asarray(time_list) + + fwd_time_arr = run_benchmark(fwd_fn) + bwd_time_arr = run_benchmark(bwd_fn) + + fwd_flops_ps = fwd_flops / np.mean(fwd_time_arr) * 1e-9 + print(f"(FWD) R={seq_rank} avg: {np.mean(fwd_time_arr)}, std: {np.std(fwd_time_arr)} flops: {fwd_flops_ps} \n") + + bwd_flops_ps = bwd_flops / np.mean(bwd_time_arr) * 1e-9 + print(f"(BWD) R={seq_rank} avg: {np.mean(bwd_time_arr)}, std: {np.std(bwd_time_arr)} flops: {bwd_flops_ps} \n") + + # total + total_time_arr = fwd_time_arr + bwd_time_arr + total_flops = fwd_flops + bwd_flops + total_flops_ps = total_flops / np.mean(total_time_arr) * 1e-9 + print(f"(Total) R={seq_rank} avg: {np.mean(total_time_arr)}, std: {np.std(total_time_arr)} flops: {total_flops_ps} \n") + + #return total_flops_ps + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--comm-mode", type=str, default="lightseq") + parser.add_argument("--debug", action="store_true") + parser.add_argument("--run-mode", type=str, default="test") + parser.add_argument("--bs", type=int, default=1) + parser.add_argument("--n_heads", type=int, default=32) + parser.add_argument("--n_kvheads", type=int, default=32) + parser.add_argument("--d_head", type=int, default=128) + parser.add_argument("--start_ctx", type=int, default=12) + parser.add_argument("--end_ctx", type=int, default=18) + parser.add_argument("--forward_engine", type=str, default="triton") + parser.add_argument("--backward_engine", type=str, default="flash") + + global args + args = parser.parse_args() + initialize_distributed() + + assert 
args.forward_engine == "triton", "Only triton forward is implmented." + assert args.backward_engine in ["flash", "xformers"], "Only flash or xformers backward is implemented." + + if args.backward_engine == "flash": + from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward + else: + try: + import xformers.ops + from xformers.ops.fmha.common import Inputs, Context + from xformers.ops.fmha import _memory_efficient_attention_backward + from xformers.ops.fmha import cutlass, flash + except ImportError: + print("xformers not found! Please install it before trying to use it.") + + if args.run_mode == "benchmark": + for N_CTX in [2**i for i in range(args.start_ctx, args.end_ctx)]: + bench_flash_attention(args.bs, args.n_heads, args.n_kvheads, N_CTX, args.d_head, True, "all", "triton", args)#.run(save_path='.', print_data=True) + reset_global_memory_buffer() + else: + assert args.run_mode == "test" + for N_CTX in [4096]: + test_op(2, 16, N_CTX, 128, True) + #test_gqa(1, 16, 8, N_CTX, 128, True) + reset_global_memory_buffer() diff --git a/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py b/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py new file mode 100644 index 0000000000..4f0ab7bfae --- /dev/null +++ b/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py @@ -0,0 +1,609 @@ +""" +Materialization-aware gradient checkpointing monkey patch. +""" +from typing import List, Optional, Tuple + +import torch +from torch import nn +from torch.utils.checkpoint import _get_autocast_kwargs, check_backward_validity, get_device_states, set_device_states, detach_variable + +import transformers +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, BaseModelOutputWithPast + +from einops import rearrange + +from .lightseq_async_attn import _lightseq_forward, _lightseq_backward +from .async_communication import initialize_distributed, reset_global_memory_buffer + + +# define a global buffer to save flash attention outputs +# it's called global because it saves the outputs for all layers +global_flash_attn_out_buffer = None + +# define a local buffer to save recomputed qkv +# it's called local because it's a temporary buffer which will be updated across layers +local_res_grad_buffer = None + +# hooks for the gradients of residual +global_hooks = [] + +def init_flash_attn_buffers(num_layers): + # update the global buffer according to number of layers + global global_flash_attn_out_buffer + global_flash_attn_out_buffer = [None] * num_layers + +def clean_hook(): + # Remove all hooks in the global buffer + for hook in global_hooks: + hook.remove() + # Clear the global buffer + global_hooks.clear() + +def clear_all_buffers_at_the_end_of_training(): + # call it at the end of training + global lobal_flash_attn_out_buffer + global_flash_attn_out_buffer = None + global local_res_grad_buffer + local_res_grad_buffer = None + clean_hook() + +def save_flash_attn_out_to_global_buffer(idx, out): + global global_flash_attn_out_buffer + global_flash_attn_out_buffer[idx] = out + +def get_flash_attn_out_from_global_buffer(idx): + global global_flash_attn_out_buffer + return global_flash_attn_out_buffer[idx] + +def free_flash_attn_out_buffer(idx): + global global_flash_attn_out_buffer + global_flash_attn_out_buffer[idx] = None + +def write_gradient_to_flash_attn_out(idx, grad): + global global_flash_attn_out_buffer + global_flash_attn_out_buffer[idx].grad = grad + +def save_res_grad_hook(grad): + global local_res_grad_buffer + local_res_grad_buffer = 
grad + +def load_and_add_res_grad_hook(grad): + grad += get_res_grad_from_local_buffer() + +def get_res_grad_from_local_buffer(): + global local_res_grad_buffer + assert local_res_grad_buffer is not None + return local_res_grad_buffer + +class CheckpointFunctionEndWithFlashAttention(torch.autograd.Function): + """ Avoid doing twice flash attention forward during checkpointed backward. + args: + hidden_states, # i.e., flash attention output which is saved in global buffer. + attention_mask, + position_ids, + residual, # the gradient of residual is saved in local buffer to pass across ckpt layers. + """ + + @staticmethod + def forward(ctx, run_function, layer_idx, preserve_rng_state, *args): + check_backward_validity(args) + ctx.run_function = run_function + ctx.layer_idx = layer_idx + ctx.preserve_rng_state = preserve_rng_state + # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu. + ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs() + if preserve_rng_state: + ctx.fwd_cpu_state = torch.get_rng_state() + # Don't eagerly initialize the cuda context by accident. + # (If the user intends that the context is initialized later, within their + # run_function, we SHOULD actually stash the cuda state here. Unfortunately, + # we have no way to anticipate this will happen before we run the function.) + ctx.had_cuda_in_fwd = False + if torch.cuda._initialized: + ctx.had_cuda_in_fwd = True + ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args) + + # Save non-tensor inputs in ctx, keep a placeholder None for tensors + # to be filled out during the backward. + ctx.inputs = [] + ctx.tensor_indices = [] + tensor_inputs = [] + for i, arg in enumerate(args): + if i == 0 and ctx.layer_idx != 0: + # flash attention output is saved to the global buffer during forward + ctx.inputs.append(None) + else: + if torch.is_tensor(arg): + tensor_inputs.append(arg) + ctx.tensor_indices.append(i) + ctx.inputs.append(None) + else: + ctx.inputs.append(arg) + + with torch.no_grad(): + q, k, v, residual = run_function(*args) + softmax_scale = q.shape[-1] ** (-0.5) + + # lightseq version + _, _, _, out, softmax_lse = _lightseq_forward(q, k, v, True, softmax_scale, comm_mode='lightseq') + rng_state = None + + # save flash attention output to global buffer + save_flash_attn_out_to_global_buffer(ctx.layer_idx, out) + tensor_inputs += [softmax_lse] + ctx.softmax_scale = softmax_scale + + ctx.save_for_backward(*tensor_inputs) + + return out, residual + + @staticmethod + def backward(ctx, *args): + if not torch.autograd._is_checkpoint_valid(): + raise RuntimeError( + "Checkpointing is not compatible with .grad() or when an `inputs` parameter" + " is passed to .backward(). Please use .backward() and do not pass its `inputs`" + " argument.") + # Copy the list to avoid modifying original list. + inputs = list(ctx.inputs) + tensor_indices = ctx.tensor_indices + tensors = ctx.saved_tensors + tensors, softmax_lse = tensors[:-1], tensors[-1] + + # Fill in inputs with appropriate saved tensors. + # Fill the flash attention output first + if ctx.layer_idx > 0: + # inputs[0] should be flash attention output + inputs[0] = get_flash_attn_out_from_global_buffer(ctx.layer_idx-1) + for i, idx in enumerate(tensor_indices): + inputs[idx] = tensors[i] + + # Stash the surrounding rng state, and mimic the state that was + # present at this time during forward. Restore the surrounding state + # when we're done. 
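+        # Forking and restoring the rng state below makes any random ops (e.g. dropout) in the
+        # recomputed forward reproduce exactly the values seen by the original forward pass.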
+ rng_devices = [] + if ctx.preserve_rng_state and ctx.had_cuda_in_fwd: + rng_devices = ctx.fwd_gpu_devices + with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state): + if ctx.preserve_rng_state: + torch.set_rng_state(ctx.fwd_cpu_state) + if ctx.had_cuda_in_fwd: + set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states) + detached_inputs = detach_variable(tuple(inputs)) + with torch.enable_grad(), \ + torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \ + torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): + # Stop recomputation before flash attention + # It is unecessary to run recomputation for flash attn + q, k, v, residual = ctx.run_function(*detached_inputs) + + # run backward() with only tensor that requires grad + # run flash attention backward first: + # get 'dout' from auto_grad inputs + # get 'out' from global buffer + # get 'qkv' from the recomputed tensors + #dq = torch.empty(q.shape, dtype=q.dtype, device=q.device) + #dk = torch.empty(k.shape, dtype=q.dtype, device=q.device) + #dv = torch.empty(v.shape, dtype=q.dtype, device=q.device) + out = get_flash_attn_out_from_global_buffer(ctx.layer_idx) + # todo get dout + dout = args[0] + + # lightseq version + dq, dk, dv = _lightseq_backward(dout, q, k, v, out, softmax_lse, ctx.softmax_scale, comm_mode='lightseq', backward_engine='flash') + #dqkv = torch.stack([dq, dk, dv]) + + # run backward for the part before flash attention + #qkv.backward(dqkv) + torch.autograd.backward([q, k, v], [dq, dk, dv]) + + grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None + for inp in detached_inputs) + + # write flash attention output gradients to buffer + if ctx.layer_idx > 0: + write_gradient_to_flash_attn_out(ctx.layer_idx-1, detached_inputs[0].grad) + + return (None, None, None) + grads + + +def checkpoint_end_with_flash_attention(function, layer_idx, *args, use_reentrant: bool = True, **kwargs): + # Hack to mix *args with **kwargs in a python 2.7-compliant way + preserve = kwargs.pop('preserve_rng_state', True) + if kwargs and use_reentrant: + raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)) + + return CheckpointFunctionEndWithFlashAttention.apply(function, layer_idx, preserve, *args) + + +class CheckpointFunctionLastModule(torch.autograd.Function): + """ + for the last ffn layer after flash attention, modifications include: + write the gradients wrt flash attention output and residual to the global buffer. + """ + + @staticmethod + def forward(ctx, run_function, preserve_rng_state, *args): + check_backward_validity(args) + ctx.run_function = run_function + ctx.preserve_rng_state = preserve_rng_state + # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu. + ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs() + if preserve_rng_state: + ctx.fwd_cpu_state = torch.get_rng_state() + # Don't eagerly initialize the cuda context by accident. + # (If the user intends that the context is initialized later, within their + # run_function, we SHOULD actually stash the cuda state here. Unfortunately, + # we have no way to anticipate this will happen before we run the function.) + ctx.had_cuda_in_fwd = False + if torch.cuda._initialized: + ctx.had_cuda_in_fwd = True + ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args) + + # Save non-tensor inputs in ctx, keep a placeholder None for tensors + # to be filled out during the backward. 
+ ctx.inputs = [] + ctx.tensor_indices = [] + tensor_inputs = [] + + assert torch.is_tensor(args[0]), "assuming the first tensor is the flash attention output" + for i, arg in enumerate(args): + if torch.is_tensor(arg) and i == 0: + # flash attn output has been saved to global buffer + ctx.inputs.append(None) + elif torch.is_tensor(arg): + tensor_inputs.append(arg) + ctx.tensor_indices.append(i) + ctx.inputs.append(None) + else: + ctx.inputs.append(arg) + + ctx.save_for_backward(*tensor_inputs) + + with torch.no_grad(): + outputs = run_function(*args) + return outputs + + @staticmethod + def backward(ctx, *args): + if not torch.autograd._is_checkpoint_valid(): + raise RuntimeError( + "Checkpointing is not compatible with .grad() or when an `inputs` parameter" + " is passed to .backward(). Please use .backward() and do not pass its `inputs`" + " argument.") + # Copy the list to avoid modifying original list. + inputs = list(ctx.inputs) + tensor_indices = ctx.tensor_indices + tensors = ctx.saved_tensors + + # Fill in inputs with appropriate saved tensors. + # Fill the flash attention output first + # inputs[0] should be flash attention output + inputs[0] = get_flash_attn_out_from_global_buffer(-1) + for i, idx in enumerate(tensor_indices): + inputs[idx] = tensors[i] + + # Stash the surrounding rng state, and mimic the state that was + # present at this time during forward. Restore the surrounding state + # when we're done. + rng_devices = [] + if ctx.preserve_rng_state and ctx.had_cuda_in_fwd: + rng_devices = ctx.fwd_gpu_devices + with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state): + if ctx.preserve_rng_state: + torch.set_rng_state(ctx.fwd_cpu_state) + if ctx.had_cuda_in_fwd: + set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states) + detached_inputs = detach_variable(tuple(inputs)) + with torch.enable_grad(), \ + torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \ + torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): + outputs = ctx.run_function(*detached_inputs) + + if isinstance(outputs, torch.Tensor): + outputs = (outputs,) + + # run backward() with only tensor that requires grad + outputs_with_grad = [] + args_with_grad = [] + for i in range(len(outputs)): + if torch.is_tensor(outputs[i]) and outputs[i].requires_grad: + outputs_with_grad.append(outputs[i]) + args_with_grad.append(args[i]) + if len(outputs_with_grad) == 0: + raise RuntimeError( + "none of output has requires_grad=True," + " this checkpoint() is not necessary") + torch.autograd.backward(outputs_with_grad, args_with_grad) + grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None + for inp in detached_inputs) + + # write flash attention output gradients to buffer + write_gradient_to_flash_attn_out(-1, detached_inputs[0].grad) + + return (None, None) + grads + +def checkpoint_last_module(function, *args, use_reentrant: bool = True, **kwargs): + preserve = kwargs.pop('preserve_rng_state', True) + if kwargs and use_reentrant: + raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)) + + return CheckpointFunctionLastModule.apply(function, preserve, *args) + + +def llama_layer_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + compute_attn_only: Optional[bool] = False, + compute_ffn_only: Optional[bool] = False, + residual: 
Optional[bool] = None, +) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + assert compute_ffn_only or compute_attn_only + + if compute_attn_only: + residual = hidden_states + + if residual.requires_grad: + # register a hook to add the gradient of residual + # from next checkpoint layer when doing recomputation + hook = residual.register_hook(load_and_add_res_grad_hook) + global_hooks.append(hook) + + hidden_states = self.input_layernorm(hidden_states) + + # Flash Attention + bsz, q_len, _ = hidden_states.size() + try: + query_states = self.self_attn.q_proj(hidden_states).view(bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim).transpose(1, 2) + key_states = self.self_attn.k_proj(hidden_states).view(bsz, q_len, self.self_attn.num_key_value_heads, self.self_attn.head_dim).transpose(1, 2) + value_states = self.self_attn.v_proj(hidden_states).view(bsz, q_len, self.self_attn.num_key_value_heads, self.self_attn.head_dim).transpose(1, 2) + except: + # old transformers versions don't support num_key_value_heads + query_states = self.self_attn.q_proj(hidden_states).view(bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim).transpose(1, 2) + key_states = self.self_attn.k_proj(hidden_states).view(bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim).transpose(1, 2) + value_states = self.self_attn.v_proj(hidden_states).view(bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + assert past_key_value is None, "past_key_value is not supported" + + cos, sin = self.self_attn.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + # [bsz, nh, t, hd] + assert not output_attentions, "output_attentions is not supported" + assert not use_cache, "use_cache is not supported" + return query_states.contiguous(), key_states.contiguous(), value_states.contiguous(), residual + + elif compute_ffn_only: + hidden_states = self.self_attn.o_proj(rearrange(hidden_states, 'b h s d -> b s (h d)')) + # Need to add residual here to make sure checkpoint is right after attention + if residual.requires_grad: + # save the gradient of residual to the local buffer + # collect the hooks which should be removed after backward to avoid memory leak + hook = residual.register_hook(save_res_grad_hook) + global_hooks.append(hook) + + hidden_states = residual + hidden_states + + # Fully Connected + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + else: + raise AttributeError + + return outputs + + +def forward( + self, + input_ids: torch.LongTensor = None, + 
attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + return_dict: Optional[bool] = None, +): + assert cache_position is None, "cache_position is not supported" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + attention_mask = None + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + try: + logger.warning_once( + "***** Using fast gradient checkpointing... *****" + ) + except: + pass + # initialize the global buffer + init_flash_attn_buffers(len(self.layers)) + + if use_cache: + try: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + except: + pass + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + # apply flash-attention friendly gradient checkpointing + if self.gradient_checkpointing: + for idx in range(len(self.layers) + 1): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + def forward_first_attn_module(module): + def custom_forward(*inputs): + hidden_states, attention_mask, position_ids, _ = inputs + # None for past_key_value + return module(hidden_states, attention_mask, position_ids, past_key_value, output_attentions, compute_attn_only=True) + return custom_forward + + def forward_ffn_attn_layer(module1, module2): + def custom_forward(*inputs): + hidden_states, attention_mask, position_ids, residual = inputs + # None for past_key_value + layer_outputs = module1(hidden_states, attention_mask, position_ids, past_key_value, output_attentions, compute_ffn_only=True, residual=residual) + hidden_states = layer_outputs[0] + return module2(hidden_states, attention_mask, position_ids, past_key_value, output_attentions, compute_attn_only=True) + return custom_forward + + def forward_last_ffn_module(module): + def custom_forward(*inputs): + hidden_states, attention_mask, position_ids, residual = inputs + # None for past_key_value + return module(hidden_states, attention_mask, position_ids, past_key_value, output_attentions, compute_ffn_only=True, residual=residual) + return custom_forward + + if idx == 0: + layer_outputs = checkpoint_end_with_flash_attention( + forward_first_attn_module(self.layers[0]), + idx, + hidden_states, + attention_mask, + position_ids, + None, + ) + hidden_states, residual = layer_outputs[0], layer_outputs[-1] + elif idx == len(self.layers): + layer_outputs = checkpoint_last_module( + forward_last_ffn_module(self.layers[-1]), + hidden_states, + attention_mask, + position_ids, + residual, + ) + hidden_states = layer_outputs[0] + else: + layer_outputs = checkpoint_end_with_flash_attention( + forward_ffn_attn_layer(self.layers[idx-1], self.layers[idx]), + idx, + hidden_states, + attention_mask, + position_ids, + residual, + ) + hidden_states, residual = layer_outputs[0], layer_outputs[-1] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + else: + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + 
last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +def apply_dist_flash_attn_monkey_patch_llama(): + initialize_distributed() + transformers.models.llama.modeling_llama.LlamaModel.forward = forward + transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = llama_layer_forward diff --git a/src/llamafactory/easy_context/dist_flash_attn/prepare_input.py b/src/llamafactory/easy_context/dist_flash_attn/prepare_input.py new file mode 100644 index 0000000000..036bbc9aa0 --- /dev/null +++ b/src/llamafactory/easy_context/dist_flash_attn/prepare_input.py @@ -0,0 +1,72 @@ + + +def extract_local(value, rank, world_size, device, dim=1): + value_local = value.chunk(world_size, dim=dim)[rank] + if device == None: + return value_local + return value_local.to(device) + + +def prepare_dist_flash_attn_inputs( + input_ids, position_ids, target_ids, rank, world_size, device +): + local_input_ids = extract_local( + input_ids, + rank, + world_size, + device, + ) + local_position_ids = extract_local( + position_ids, + rank, + world_size, + device, + ) + if target_ids is not None: + local_target_ids = extract_local( + target_ids, + rank, + world_size, + device, + ) + else: + local_target_ids = None + return { + "local_input_ids": local_input_ids, + "local_position_ids": local_position_ids, + "local_target_ids": local_target_ids, + } + +def prepare_dist_flash_attn_sft_inputs( + input_ids, attention_mask, position_ids, labels, rank, world_size, device +): + local_input_ids = extract_local( + input_ids, + rank, + world_size, + device, + ) + local_position_ids = extract_local( + position_ids, + rank, + world_size, + device, + ) + local_attention_mask = extract_local( + attention_mask, + rank, + world_size, + device + ) + local_labels = extract_local( + labels, + rank, + world_size, + device, + ) + return { + "input_ids": local_input_ids, + "attention_mask": local_attention_mask, + "position_ids": local_position_ids, + "labels": local_labels, + } \ No newline at end of file diff --git a/src/llamafactory/easy_context/ulysses_attn/monkey_patch.py b/src/llamafactory/easy_context/ulysses_attn/monkey_patch.py new file mode 100644 index 0000000000..a2cba43022 --- /dev/null +++ b/src/llamafactory/easy_context/ulysses_attn/monkey_patch.py @@ -0,0 +1,107 @@ +import transformers +from typing import List, Optional, Tuple, Union +import warnings +import torch +import torch.utils.checkpoint +from yunchang.ulysses import UlyssesAttention + +ulysses_attn = UlyssesAttention() + +def new_flash_attn_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + use_sliding_windows=False, +): + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + assert attention_mask is None + assert causal is True + assert use_sliding_windows is False + attn_output = ulysses_attn( + query_states, + key_states, + value_states, + dropout, + softmax_scale, + causal=causal, + ) + + return attn_output + + +def new_decoder_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + 
**kwargs, +) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + assert isinstance( + self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2 + ) or isinstance( + self.self_attn, + transformers.models.mistral.modeling_mistral.MistralFlashAttention2, + ), "Please toggle on the Flash Attention 2 implementation when using zigzag ring attention monkey patch." + + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +def apply_ulysses_attn_monkey_patch_llama(): + transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = ( + new_flash_attn_forward + ) + transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = ( + new_decoder_forward + ) + + diff --git a/src/llamafactory/easy_context/ulysses_attn/prepare_inputs.py b/src/llamafactory/easy_context/ulysses_attn/prepare_inputs.py new file mode 100644 index 0000000000..de6a0b50b4 --- /dev/null +++ b/src/llamafactory/easy_context/ulysses_attn/prepare_inputs.py @@ -0,0 +1,80 @@ +import torch + + +def extract_local(value, rank, world_size, device, dim=1): + dimension_size = value.shape[dim] + sub_seq_length = dimension_size // world_size + + sub_seq_start = rank * sub_seq_length + sub_seq_end = (rank + 1) * sub_seq_length + local_value = value[:, sub_seq_start:sub_seq_end] + if device == None: + return local_value + return local_value.to(device) + + +def prepare_ulysses_attn_inputs( + input_ids, position_ids, target_ids, rank, world_size, device +): + + local_input_ids = extract_local( + input_ids, + rank, + world_size, + device, + ) + local_position_ids = extract_local( + position_ids, + rank, + world_size, + device, + ) + + if target_ids is not None: + local_target_ids = extract_local( + target_ids, + rank, + world_size, + device, + ) + else: + local_target_ids = None + return { + "local_input_ids": local_input_ids, + "local_position_ids": local_position_ids, + "local_target_ids": local_target_ids, + } + +def prepare_ulysses_attn_sft_inputs( + input_ids, attention_mask, position_ids, labels, rank, world_size, device +): + local_input_ids = extract_local( + input_ids, + rank, + world_size, + device, + ) + local_position_ids = extract_local( + position_ids, + rank, + world_size, + device, + ) + local_attention_mask = extract_local( + attention_mask, + rank, + world_size, + device + ) + local_labels = extract_local( + labels, + rank, + world_size, + device, + ) + return { + "input_ids": local_input_ids, + "attention_mask": local_attention_mask, + "position_ids": local_position_ids, + "labels": local_labels, + } \ No newline at end of file 
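For intuition about the two sequence-sharding schemes used in this patch, here is a minimal standalone sketch (not taken from the patch; the world size and tensor sizes are made up for illustration). It mirrors the logic of `extract_local` above (Ulysses: one contiguous slice of the sequence per rank) and of the zigzag variant in `zigzag_ring_attn/prepare_inputs.py` further below (each rank keeps chunk i and chunk 2*world_size-1-i so that causal-attention work is balanced across ranks):

import torch

world_size = 4                               # assumed number of sequence-parallel ranks
input_ids = torch.arange(32).view(1, 32)     # toy batch: 1 sample, 32 tokens

def ulysses_local(value, rank, dim=1):
    # contiguous 1/world_size slice of the sequence dimension
    sub = value.shape[dim] // world_size
    return value[:, rank * sub:(rank + 1) * sub]

def zigzag_local(value, rank, dim=1):
    # pair chunk `rank` with its mirror chunk to balance causal-attention cost
    chunks = value.chunk(2 * world_size, dim=dim)
    return torch.cat([chunks[rank], chunks[2 * world_size - rank - 1]], dim=dim)

# every rank ends up with seq_len // world_size tokens of each sample
assert all(ulysses_local(input_ids, r).shape[1] == 8 for r in range(world_size))
assert all(zigzag_local(input_ids, r).shape[1] == 8 for r in range(world_size))
# Ulysses shards are contiguous, so rank-ordered concatenation restores the batch
assert torch.equal(
    torch.cat([ulysses_local(input_ids, r) for r in range(world_size)], dim=1), input_ids
)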
diff --git a/src/llamafactory/easy_context/unsloth_offloaded_gradient_checkpoint/monkey_patch.py b/src/llamafactory/easy_context/unsloth_offloaded_gradient_checkpoint/monkey_patch.py new file mode 100644 index 0000000000..fb509e0ef2 --- /dev/null +++ b/src/llamafactory/easy_context/unsloth_offloaded_gradient_checkpoint/monkey_patch.py @@ -0,0 +1,94 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import transformers +import inspect + + +class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function): + """ + Saves VRAM by smartly offloading to RAM. + Tiny hit to performance, since we mask the movement via non blocking calls. + """ + + @staticmethod + @torch.cuda.amp.custom_fwd + def forward(ctx, forward_function, hidden_states, *args): + saved_hidden_states = hidden_states.to("cpu", non_blocking=True) + with torch.no_grad(): + output = forward_function(hidden_states, *args) + ctx.save_for_backward(saved_hidden_states) + ctx.forward_function = forward_function + ctx.args = args + + return output + + pass + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(ctx, dY): + (hidden_states,) = ctx.saved_tensors + hidden_states = hidden_states.to("cuda", non_blocking=True).detach() + hidden_states.requires_grad = True + with torch.enable_grad(): + (output,) = ctx.forward_function(hidden_states, *ctx.args) + torch.autograd.backward(output, dY) + return ( + None, + hidden_states.grad, + ) + ( + None, + ) * len(ctx.args) + + pass + + +pass + + +def new_gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): + assert gradient_checkpointing_kwargs == None + if not self.supports_gradient_checkpointing: + raise ValueError( + f"{self.__class__.__name__} does not support gradient checkpointing." + ) + + gradient_checkpointing_func = Unsloth_Offloaded_Gradient_Checkpointer.apply + # For old GC format (transformers < 4.35.0) for models that live on the Hub + # we will fall back to the overwritten `_set_gradient_checkpointing` method + _is_using_old_format = ( + "value" in inspect.signature(self._set_gradient_checkpointing).parameters + ) + + if not _is_using_old_format: + self._set_gradient_checkpointing( + enable=True, gradient_checkpointing_func=gradient_checkpointing_func + ) + else: + raise NotImplementedError() + + if getattr(self, "_hf_peft_config_loaded", False): + # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True + # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334 + # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate + # the gradients to make sure the gradient flows. 
+ self.enable_input_require_grads() + + +def apply_unsloth_offloaded_gradient_checkpoint_monkey_patch(): + transformers.modeling_utils.PreTrainedModel.gradient_checkpointing_enable = ( + new_gradient_checkpointing_enable + ) diff --git a/src/llamafactory/easy_context/zigzag_ring_attn/monkey_patch.py b/src/llamafactory/easy_context/zigzag_ring_attn/monkey_patch.py new file mode 100644 index 0000000000..4ebcdb7d05 --- /dev/null +++ b/src/llamafactory/easy_context/zigzag_ring_attn/monkey_patch.py @@ -0,0 +1,113 @@ +import transformers +from typing import List, Optional, Tuple, Union +import warnings +import torch +import torch.utils.checkpoint +from ring_flash_attn.zigzag_ring_flash_attn import zigzag_ring_flash_attn_func + + +def new_flash_attn_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + use_sliding_windows=False, +): + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + assert attention_mask is None + assert causal is True + assert use_sliding_windows is False + attn_output = zigzag_ring_flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale, + causal=causal, + ) + + return attn_output + + +def new_decoder_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, +) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + assert isinstance( + self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2 + ) or isinstance( + self.self_attn, + transformers.models.mistral.modeling_mistral.MistralFlashAttention2, + ), "Please toggle on the Flash Attention 2 implementation when using zigzag ring attention monkey patch." + + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +def apply_zigzag_ring_attn_monkey_patch_llama(): + transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = ( + new_flash_attn_forward + ) + transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = ( + new_decoder_forward + ) + + +def apply_zigzag_ring_attn_monkey_patch_mistral(): + transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward = ( + new_flash_attn_forward + ) + transformers.models.mistral.modeling_mistral.MistralDecoderLayer.forward = ( + new_decoder_forward + ) diff --git a/src/llamafactory/easy_context/zigzag_ring_attn/prepare_inputs.py b/src/llamafactory/easy_context/zigzag_ring_attn/prepare_inputs.py new file mode 100644 index 0000000000..24f7e4d467 --- /dev/null +++ b/src/llamafactory/easy_context/zigzag_ring_attn/prepare_inputs.py @@ -0,0 +1,76 @@ +import torch + + +def extract_local(value, rank, world_size, device, dim=1): + value_chunks = value.chunk(2 * world_size, dim=dim) + local_value = torch.cat( + [value_chunks[rank], value_chunks[2 * world_size - rank - 1]], dim=dim + ) + if device == None: + return local_value + return local_value.to(device) + + +def prepare_zigzag_ring_attn_inputs( + input_ids, position_ids, target_ids, rank, world_size, device +): + local_input_ids = extract_local( + input_ids, + rank, + world_size, + device, + ) + local_position_ids = extract_local( + position_ids, + rank, + world_size, + device, + ) + if target_ids is not None: + local_target_ids = extract_local( + target_ids, + rank, + world_size, + device, + ) + else: + local_target_ids = None + return { + "local_input_ids": local_input_ids, + "local_position_ids": local_position_ids, + "local_target_ids": local_target_ids, + } + +def prepare_zigzag_ring_attn_sft_inputs( + input_ids, attention_mask, position_ids, labels, rank, world_size, device +): + local_input_ids = extract_local( + input_ids, + rank, + world_size, + device, + ) + local_position_ids = extract_local( + position_ids, + rank, + world_size, + device, + ) + local_attention_mask = extract_local( + attention_mask, + rank, + world_size, + device + ) + local_labels = extract_local( + labels, + rank, + world_size, + device, + ) + return { + "input_ids": local_input_ids, + "attention_mask": local_attention_mask, + "position_ids": local_position_ids, + "labels": local_labels, + } \ No newline at end of file diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index facbe792ca..5d0f52993d 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -312,6 +312,10 @@ class 
FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA default=False, metadata={"help": "Whether or not to save the training loss curves."}, ) + parallel_mode: Literal["zigzag_ring_attn", "dist_flash_attn", "ulysses_attn", "data_parallel"] = field( + default="data_parallel", + metadata={"help": "which sequence parallel mode to use."}, + ) def __post_init__(self): def split_arg(arg): diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index c063b214df..8a569afab4 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -10,7 +10,11 @@ from ...extras.constants import IGNORE_INDEX from ...extras.logging import get_logger from ..trainer_utils import create_custom_optimzer, create_custom_scheduler - +from torch.utils.data import DataLoader +from transformers.utils import is_datasets_available +from transformers.trainer_utils import seed_worker +import datasets +from torch.nn import CrossEntropyLoss if TYPE_CHECKING: from transformers import ProcessorMixin @@ -130,3 +134,148 @@ def save_predictions(self, predict_results: "PredictionOutput") -> None: for label, pred in zip(decoded_labels, decoded_preds): res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False)) writer.write("\n".join(res)) + +class CustomSeqParallelTrainer(CustomSeq2SeqTrainer): + + def compute_loss(self, model, inputs, return_outputs=False): + """ + How the loss is computed by Trainer. By default, all models return the loss in the first element. + + Subclass and override for custom behavior. + """ + if self.label_smoother is not None and "labels" in inputs: + labels = inputs.pop("labels") + else: + labels = None + outputs = model(**inputs) + # Save past state if it exists + # TODO: this needs to be fixed and made cleaner later. + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index] + + if labels is not None: + unwrapped_model = self.accelerator.unwrap_model(model) + if _is_peft_model(unwrapped_model): + model_name = unwrapped_model.base_model.model._get_name() + else: + model_name = unwrapped_model._get_name() + if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): + loss = self.label_smoother(outputs, labels, shift_labels=True) + else: + loss = self.label_smoother(outputs, labels) + else: + if isinstance(outputs, dict) and "loss" not in outputs: + raise ValueError( + "The model did not return a loss from the inputs, only the following keys: " + f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}." + ) + # We don't use .loss here since the model may return tuples instead of ModelOutput. 
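+            # For the sequence-parallel modes handled below, each rank only sees its
+            # own shard of every sequence, so the token losses are summed locally,
+            # normalized by the number of valid (label != -100) tokens gathered from
+            # all ranks, and finally scaled by n_gpus to undo the cross-rank gradient
+            # averaging; the result matches the per-token mean loss of the full,
+            # unsharded batch.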
+ if self.finetuning_args.parallel_mode== "data_parallel": + loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] + #print(f"loss={loss}, rank={os.environ['RANK']}") + #time.sleep(60) + else: + loss_fn = CrossEntropyLoss(reduction='sum') + labels = inputs.pop("labels") + logits = outputs["logits"] if isinstance(outputs, dict) else outputs[1] + valid_label_cnt = (labels!=-100).sum(1)[None, :] + valid_label_cnt_gather = self.accelerator.gather(valid_label_cnt) + n_gpus = valid_label_cnt_gather.shape[0] + valid_label_cnt_all =valid_label_cnt_gather.sum(0) + shift_logits = logits.contiguous() + shift_labels = labels.contiguous() + bs = len(shift_labels) + loss = torch.zeros(bs, dtype=shift_logits.dtype, device=shift_labels.device) + for b in range(bs): + normalizer=valid_label_cnt_all[b].item() + loss[b]=loss_fn(shift_logits[b], shift_labels[b])/normalizer + #print(f"loss={loss}, rank={os.environ['RANK']}") + #time.sleep(60) + loss = loss.mean()*n_gpus + + return (loss, outputs) if return_outputs else loss + + def get_train_dataloader(self) -> DataLoader: + """ + Returns the training [`~torch.utils.data.DataLoader`]. + + Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed + training if necessary) otherwise. + + Subclass and override this method if you want to inject some custom behavior. + """ + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + train_dataset = self.train_dataset + data_collator = self.data_collator + if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): + train_dataset = self._remove_unused_columns(train_dataset, description="training") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="training") + + dataloader_params = { + "batch_size": self._train_batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": False, + "persistent_workers": self.args.dataloader_persistent_workers, + } + + if not isinstance(train_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_train_sampler() + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["worker_init_fn"] = seed_worker + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + if hasattr(data_collator, "seq_algo") and data_collator.seq_algo != "data_parallel": + return DataLoader(train_dataset, **dataloader_params) + return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + + def get_eval_dataloader(self, eval_dataset) -> DataLoader: + """ + Returns the evaluation [`~torch.utils.data.DataLoader`]. + + Subclass and override this method if you want to inject some custom behavior. + + Args: + eval_dataset (`torch.utils.data.Dataset`, *optional*): + If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted + by the `model.forward()` method are automatically removed. It must implement `__len__`. 
+ """ + if eval_dataset is None and self.eval_dataset is None: + raise ValueError("Trainer: evaluation requires an eval_dataset.") + + # If we have persistent workers, don't do a fork bomb especially as eval datasets + # don't change during training + if hasattr(self, "_eval_dataloader") and self.args.dataloader_persistent_workers: + return self.accelerator.prepare(self._eval_dataloader) + eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + data_collator = self.data_collator + + if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset): + eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation") + + dataloader_params = { + "batch_size": self.args.eval_batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": False, + "persistent_workers": self.args.dataloader_persistent_workers, + } + + if not isinstance(eval_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset) + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + # accelerator.free_memory() will destroy the references, so + # we need to store the non-prepared version + eval_dataloader = DataLoader(eval_dataset, **dataloader_params) + if self.args.dataloader_persistent_workers: + self._eval_dataloader = eval_dataloader + + if hasattr(data_collator, "seq_algo") and data_collator.seq_algo != "data_parallel": + return eval_dataloader + return self.accelerator.prepare(eval_dataloader) diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py index f09b51730b..098396a074 100644 --- a/src/llamafactory/train/sft/workflow.py +++ b/src/llamafactory/train/sft/workflow.py @@ -5,13 +5,18 @@ from transformers import DataCollatorForSeq2Seq from ...data import get_dataset, split_dataset +from ...data.collator import SeqParallelDataCollator from ...extras.constants import IGNORE_INDEX from ...extras.misc import get_logits_processor from ...extras.ploting import plot_loss from ...model import load_model, load_tokenizer from ..trainer_utils import create_modelcard_and_push from .metric import ComputeMetrics -from .trainer import CustomSeq2SeqTrainer +from .trainer import CustomSeq2SeqTrainer, CustomSeqParallelTrainer + +import torch +import os +from ...easy_context import apply_seq_parallel_monkey_patch if TYPE_CHECKING: @@ -32,6 +37,7 @@ def run_sft( tokenizer = tokenizer_module["tokenizer"] dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) + apply_seq_parallel_monkey_patch(finetuning_args.parallel_mode, "llama") if training_args.predict_with_generate: tokenizer.padding_side = "left" # use left-padding in generation @@ -39,19 +45,25 @@ def run_sft( if getattr(model, "is_quantized", False) and not training_args.do_train: setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction - data_collator = DataCollatorForSeq2Seq( + local_rank = int(os.getenv("LOCAL_RANK")) + world_size = torch.distributed.get_world_size() + print(f"seq_len: {data_args.cutoff_len}") + data_collator = SeqParallelDataCollator( tokenizer=tokenizer, - pad_to_multiple_of=8 if tokenizer.padding_side == "right" else 
None, # for shift short attention + pad_to_multiple_of=data_args.cutoff_len if tokenizer.padding_side == "right" else None, label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id, + seq_algo=finetuning_args.parallel_mode, + rank=torch.distributed.get_rank(), + world_size=world_size, + device=torch.device("cuda", local_rank) ) - # Override the decoding parameters of Seq2SeqTrainer training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams training_args.remove_unused_columns = False if model_args.visual_inputs else training_args.remove_unused_columns # Initialize our Trainer - trainer = CustomSeq2SeqTrainer( + trainer = CustomSeqParallelTrainer( model=model, args=training_args, finetuning_args=finetuning_args, From 7544719497bb34dcb08a0b1426d05fba3344f6bd Mon Sep 17 00:00:00 2001 From: qianhao0713 <475483052@qq.com> Date: Tue, 9 Jul 2024 20:03:11 +0800 Subject: [PATCH 02/31] merge easycontext into llamafactory --- .gitignore | 8 + Llama3-70B-pt-dp.sh | 178 ++++++++++++++++++ Llama3-70B-pt-sp-lora.sh | 178 ++++++++++++++++++ Llama3-70B-pt-sp.sh | 178 ++++++++++++++++++ Llama3-70B-sft.sh | 176 +++++++++++++++++ src/llamafactory/data/__init__.py | 3 +- src/llamafactory/data/collator.py | 50 ++++- src/llamafactory/data/loader.py | 17 +- .../dist_flash_attn/monkey_patch.py | 3 +- .../dist_flash_attn/prepare_input.py | 2 +- src/llamafactory/extras/logging.py | 4 +- src/llamafactory/hparams/parser.py | 2 +- src/llamafactory/model/loader.py | 2 +- src/llamafactory/train/pt/trainer.py | 152 ++++++++++++++- src/llamafactory/train/pt/workflow.py | 54 ++++-- src/llamafactory/train/sft/trainer.py | 5 +- src/llamafactory/train/sft/workflow.py | 34 ++-- src/llamafactory/train/trainer_utils.py | 2 + src/llamafactory/train/tuner.py | 1 + 19 files changed, 994 insertions(+), 55 deletions(-) create mode 100644 Llama3-70B-pt-dp.sh create mode 100644 Llama3-70B-pt-sp-lora.sh create mode 100644 Llama3-70B-pt-sp.sh create mode 100644 Llama3-70B-sft.sh diff --git a/.gitignore b/.gitignore index 0355c66607..8dc3597ec5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,11 @@ +# local directories to ignore +/output/ +*.json +*llmtuner* +/wandb/ +/examples/ +/data/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/Llama3-70B-pt-dp.sh b/Llama3-70B-pt-dp.sh new file mode 100644 index 0000000000..c9f7de7dd4 --- /dev/null +++ b/Llama3-70B-pt-dp.sh @@ -0,0 +1,178 @@ +# You can observe that the number of steps for different stage is quite different. They are not magic number. They are set to those numbers simply because I esitimate the time it takes to finish the training, and +# choose the number such that it fits my daily schedule>_<. This is for you to exactly reproduce my results. You many change the steps to other numbers if you want to. 
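+# With the defaults below (SEQ_LEN=32768, BATCH_SIZE=1, ACC=4, NGPUS=8, WORLD_SIZE=1)
+# and --parallel_mode data_parallel, one optimizer step consumes
+# SEQ_LEN * BATCH_SIZE * ACC * NGPUS * WORLD_SIZE = 32768 * 1 * 4 * 8 = 1,048,576 tokens,
+# i.e. the "1M" batch referred to by "bs_1M" in SAVE_PATH.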
+MODEL_DIR=${MODEL_DIR:-"/mnt/zj-gpfs/home/lsy/models/Meta-Llama-3-8B"} +NGPUS=${NGPUS:-8} +WORLD_SIZE=${WORLD_SIZE:-1} +NUM_PROCESSES=$((${NGPUS}*${WORLD_SIZE})) +SEQ_LEN=${SEQ_LEN:-32768} +BATCH_SIZE=${BATCH_SIZE:-1} +ACC=${ACC:-4} +SAVE_PATH=${SAVE_PATH:-"./output/70B_32K_bs_1M_rope_1M_step_1000_lr_2e-5"} +SAVE_STEPS=${SAVE_STEPS:-500} +export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' +export NCCL_P2P_DISABLE=1 +export NCCL_IB_GID_INDEX=3 +export WANDB_DISABLED=true +export NCCL_DEBUG=WARN +export CUDA_LAUNCH_BLOCKING=1 +echo ${RANK}/$[WORLD_SIZE] +if [ ${MASTER_ADDR} == 'localhost' ]; then + export MASTER_ADDR=`hostname -i` +fi +echo ${MASTER_ADDR}:${MASTER_PORT} +export TRANSFORMERS_VERBOSITY='debug' + +accelerate launch \ +--config_file examples/accelerate/ds_multi_nodes.yaml \ +--use_deepspeed \ +--num_machines ${WORLD_SIZE} \ +--num_processes ${NUM_PROCESSES} \ +--main_process_ip ${MASTER_ADDR} \ +--main_process_port ${MASTER_PORT} \ +--machine_rank ${RANK} \ +--rdzv_conf "rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT,rdzv_backend=c10d" \ +src/train.py \ +--model_name_or_path ${MODEL_DIR} \ +--stage pt \ +--do_train \ +--finetuning_type full \ +--parallel_mode data_parallel \ +--deepspeed examples/deepspeed/ds_z3_offload_config.json \ +--tokenized_path /mnt/zj-gpfs/home/lsy/data/tokenized_c4_demo \ +--template llama3 \ +--cutoff_len ${SEQ_LEN} \ +--max_samples 1000 \ +--overwrite_cache \ +--preprocessing_num_workers 16 \ +--output_dir ${SAVE_PATH} \ +--logging_steps 1 \ +--save_steps ${SAVE_STEPS} \ +--plot_loss \ +--overwrite_output_dir \ +--per_device_train_batch_size ${BATCH_SIZE} \ +--gradient_accumulation_steps ${ACC} \ +--learning_rate 2e-5 \ +--num_train_epochs 3.0 \ +--lr_scheduler_type cosine \ +--warmup_ratio 0.1 \ +--bf16 \ +--ddp_timeout 180000000 \ +--val_size 0.1 \ +--eval_strategy steps \ +--eval_steps 9999 \ +--dataloader_drop_last + +# In the saved files, there are model-00001-of-00003.safetensors to model-00001-of-00003.safetensors. Somehow model.safetensors is unnecessary and should be removed. 
+# rm output/7B_32K_bs_1M_rope_1M_step_1000_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 2 \ +# --output-dir ./output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5 \ +# --seed 2022 \ +# --wandb EasyContext \ +# --max-train-steps 1000 \ +# --learning-rate 2e-5 \ +# --dataset yaofu/slimpajama-per-source-length-upsample \ +# --model output/7B_32K_bs_1M_rope_1M_step_1000_lr_2e-5 \ +# --seq-length 65536 \ +# --rope-theta 5000000 \ +# --parallel_mode data_parallel + +# rm output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 4 \ +# --output-dir ./output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5 \ +# --seed 2023 \ +# --wandb EasyContext \ +# --max-train-steps 500 \ +# --learning-rate 2e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ +# --model output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5 \ +# --seq-length 256000 \ +# --rope-theta 10000000 \ +# --parallel_mode zigzag_ring_attn + +# rm output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 4 \ +# --output-dir ./output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5 \ +# --seed 2024 \ +# --wandb EasyContext \ +# --max-train-steps 500 \ +# --learning-rate 2e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ +# --model output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5 \ +# --seq-length 256000 \ +# --rope-theta 25000000 \ +# --parallel_mode zigzag_ring_attn + +# rm output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 4 \ +# --output-dir ./output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5 \ +# --seed 2025 \ +# --wandb EasyContext \ +# --max-train-steps 150 \ +# --learning-rate 2e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ +# --model output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5 \ +# --seq-length 256000 \ +# --rope-theta 50000000 \ +# --parallel_mode zigzag_ring_attn + +# rm output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 2 \ +# --output-dir ./output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5 \ +# --seed 2026 \ +# --wandb EasyContext \ +# --max-train-steps 300 \ +# --learning-rate 2e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_1M \ +# --model output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5 \ +# --seq-length 512000 \ +# --rope-theta 100000000 \ +# --parallel_mode zigzag_ring_attn + + +# rm output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 2 \ +# --output-dir ./output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5 \ +# --seed 2027 \ +# --wandb EasyContext \ +# --max-train-steps 90 \ +# --learning-rate 1e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_1M \ +# --model output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5 \ +# --seq-length 
512000 \ +# --rope-theta 250000000 \ +# --parallel_mode zigzag_ring_attn + +# rm output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5/model.safetensors + + +### Finally we directly set the rope_theta in output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5/config.json to 1,000,000,000 diff --git a/Llama3-70B-pt-sp-lora.sh b/Llama3-70B-pt-sp-lora.sh new file mode 100644 index 0000000000..f6367490ab --- /dev/null +++ b/Llama3-70B-pt-sp-lora.sh @@ -0,0 +1,178 @@ +# You can observe that the number of steps for different stage is quite different. They are not magic number. They are set to those numbers simply because I esitimate the time it takes to finish the training, and +# choose the number such that it fits my daily schedule>_<. This is for you to exactly reproduce my results. You many change the steps to other numbers if you want to. +MODEL_DIR=${MODEL_DIR:-"/mnt/zj-gpfs/home/lsy/models/Meta-Llama-3-8B"} +NGPUS=${NGPUS:-8} +WORLD_SIZE=${WORLD_SIZE:-1} +NUM_PROCESSES=$((${NGPUS}*${WORLD_SIZE})) +SEQ_LEN=${SEQ_LEN:-32768} +BATCH_SIZE=${BATCH_SIZE:-1} +ACC=${ACC:-4} +SAVE_PATH=${SAVE_PATH:-"./output/70B_32K_bs_1M_rope_1M_step_1000_lr_2e-5"} +SAVE_STEPS=${SAVE_STEPS:-500} +export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' +export NCCL_P2P_DISABLE=1 +export NCCL_IB_GID_INDEX=3 +export WANDB_DISABLED=true +export NCCL_DEBUG=WARN +export CUDA_LAUNCH_BLOCKING=1 +echo ${RANK}/$[WORLD_SIZE] +if [ ${MASTER_ADDR} == 'localhost' ]; then + export MASTER_ADDR=`hostname -i` +fi +echo ${MASTER_ADDR}:${MASTER_PORT} +export TRANSFORMERS_VERBOSITY='debug' + +accelerate launch \ +--config_file examples/accelerate/ds_multi_nodes.yaml \ +--use_deepspeed \ +--num_machines ${WORLD_SIZE} \ +--num_processes ${NUM_PROCESSES} \ +--main_process_ip ${MASTER_ADDR} \ +--main_process_port ${MASTER_PORT} \ +--machine_rank ${RANK} \ +--rdzv_conf "rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT,rdzv_backend=c10d" \ +src/train.py \ +--model_name_or_path ${MODEL_DIR} \ +--stage pt \ +--do_train \ +--finetuning_type lora \ +--parallel_mode dist_flash_attn \ +--deepspeed examples/deepspeed/ds_z3_offload_config.json \ +--tokenized_path ${DATA_PATH:-"/mnt/zj-gpfs/home/lsy/data/per_source_upsample_32769_common_5b"} \ +--template llama3 \ +--cutoff_len ${SEQ_LEN} \ +--max_samples 1000000 \ +--overwrite_cache \ +--preprocessing_num_workers 16 \ +--output_dir ${SAVE_PATH} \ +--logging_steps 1 \ +--save_steps ${SAVE_STEPS} \ +--plot_loss \ +--overwrite_output_dir \ +--per_device_train_batch_size ${BATCH_SIZE} \ +--gradient_accumulation_steps ${ACC} \ +--learning_rate 2e-5 \ +--num_train_epochs 3.0 \ +--lr_scheduler_type cosine \ +--warmup_ratio 0.1 \ +--bf16 \ +--ddp_timeout 180000000 \ +--val_size 0.1 \ +--eval_strategy steps \ +--eval_steps 99999999 \ +--dataloader_drop_last + +# In the saved files, there are model-00001-of-00003.safetensors to model-00001-of-00003.safetensors. Somehow model.safetensors is unnecessary and should be removed. 
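+# Note: this run uses --finetuning_type lora, so the output directory normally
+# contains PEFT adapter weights (adapter_model.safetensors) rather than full model
+# shards, and the model.safetensors cleanup mentioned above may not be needed.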
+# rm output/7B_32K_bs_1M_rope_1M_step_1000_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 2 \ +# --output-dir ./output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5 \ +# --seed 2022 \ +# --wandb EasyContext \ +# --max-train-steps 1000 \ +# --learning-rate 2e-5 \ +# --dataset yaofu/slimpajama-per-source-length-upsample \ +# --model output/7B_32K_bs_1M_rope_1M_step_1000_lr_2e-5 \ +# --seq-length 65536 \ +# --rope-theta 5000000 \ +# --parallel_mode data_parallel + +# rm output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 4 \ +# --output-dir ./output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5 \ +# --seed 2023 \ +# --wandb EasyContext \ +# --max-train-steps 500 \ +# --learning-rate 2e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ +# --model output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5 \ +# --seq-length 256000 \ +# --rope-theta 10000000 \ +# --parallel_mode zigzag_ring_attn + +# rm output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 4 \ +# --output-dir ./output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5 \ +# --seed 2024 \ +# --wandb EasyContext \ +# --max-train-steps 500 \ +# --learning-rate 2e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ +# --model output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5 \ +# --seq-length 256000 \ +# --rope-theta 25000000 \ +# --parallel_mode zigzag_ring_attn + +# rm output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 4 \ +# --output-dir ./output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5 \ +# --seed 2025 \ +# --wandb EasyContext \ +# --max-train-steps 150 \ +# --learning-rate 2e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ +# --model output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5 \ +# --seq-length 256000 \ +# --rope-theta 50000000 \ +# --parallel_mode zigzag_ring_attn + +# rm output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 2 \ +# --output-dir ./output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5 \ +# --seed 2026 \ +# --wandb EasyContext \ +# --max-train-steps 300 \ +# --learning-rate 2e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_1M \ +# --model output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5 \ +# --seq-length 512000 \ +# --rope-theta 100000000 \ +# --parallel_mode zigzag_ring_attn + + +# rm output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 2 \ +# --output-dir ./output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5 \ +# --seed 2027 \ +# --wandb EasyContext \ +# --max-train-steps 90 \ +# --learning-rate 1e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_1M \ +# --model output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5 \ +# --seq-length 
512000 \ +# --rope-theta 250000000 \ +# --parallel_mode zigzag_ring_attn + +# rm output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5/model.safetensors + + +### Finally we directly set the rope_theta in output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5/config.json to 1,000,000,000 diff --git a/Llama3-70B-pt-sp.sh b/Llama3-70B-pt-sp.sh new file mode 100644 index 0000000000..5d52895af0 --- /dev/null +++ b/Llama3-70B-pt-sp.sh @@ -0,0 +1,178 @@ +# You can observe that the number of steps for different stage is quite different. They are not magic number. They are set to those numbers simply because I esitimate the time it takes to finish the training, and +# choose the number such that it fits my daily schedule>_<. This is for you to exactly reproduce my results. You many change the steps to other numbers if you want to. +MODEL_DIR=${MODEL_DIR:-"/mnt/zj-gpfs/home/lsy/models/Meta-Llama-3-8B"} +NGPUS=${NGPUS:-8} +WORLD_SIZE=${WORLD_SIZE:-1} +NUM_PROCESSES=$((${NGPUS}*${WORLD_SIZE})) +SEQ_LEN=${SEQ_LEN:-32768} +BATCH_SIZE=${BATCH_SIZE:-1} +ACC=${ACC:-4} +SAVE_PATH=${SAVE_PATH:-"./output/70B_32K_bs_1M_rope_1M_step_1000_lr_2e-5"} +SAVE_STEPS=${SAVE_STEPS:-500} +export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' +export NCCL_P2P_DISABLE=1 +export NCCL_IB_GID_INDEX=3 +export WANDB_DISABLED=true +export NCCL_DEBUG=WARN +export CUDA_LAUNCH_BLOCKING=1 +echo ${RANK}/$[WORLD_SIZE] +if [ ${MASTER_ADDR} == 'localhost' ]; then + export MASTER_ADDR=`hostname -i` +fi +echo ${MASTER_ADDR}:${MASTER_PORT} +export TRANSFORMERS_VERBOSITY='debug' + +accelerate launch \ +--config_file examples/accelerate/ds_multi_nodes.yaml \ +--use_deepspeed \ +--num_machines ${WORLD_SIZE} \ +--num_processes ${NUM_PROCESSES} \ +--main_process_ip ${MASTER_ADDR} \ +--main_process_port ${MASTER_PORT} \ +--machine_rank ${RANK} \ +--rdzv_conf "rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT,rdzv_backend=c10d" \ +src/train.py \ +--model_name_or_path ${MODEL_DIR} \ +--stage pt \ +--do_train \ +--finetuning_type full \ +--parallel_mode dist_flash_attn \ +--deepspeed examples/deepspeed/ds_z3_offload_config.json \ +--tokenized_path /mnt/zj-gpfs/home/lsy/data/tokenized_c4_demo \ +--template llama3 \ +--cutoff_len ${SEQ_LEN} \ +--max_samples 1000 \ +--overwrite_cache \ +--preprocessing_num_workers 16 \ +--output_dir ${SAVE_PATH} \ +--logging_steps 1 \ +--save_steps ${SAVE_STEPS} \ +--plot_loss \ +--overwrite_output_dir \ +--per_device_train_batch_size ${BATCH_SIZE} \ +--gradient_accumulation_steps ${ACC} \ +--learning_rate 2e-5 \ +--num_train_epochs 3.0 \ +--lr_scheduler_type cosine \ +--warmup_ratio 0.1 \ +--bf16 \ +--ddp_timeout 180000000 \ +--val_size 0.1 \ +--eval_strategy steps \ +--eval_steps 9999 \ +--dataloader_drop_last + +# In the saved files, there are model-00001-of-00003.safetensors to model-00001-of-00003.safetensors. Somehow model.safetensors is unnecessary and should be removed. 
+# rm output/7B_32K_bs_1M_rope_1M_step_1000_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 2 \ +# --output-dir ./output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5 \ +# --seed 2022 \ +# --wandb EasyContext \ +# --max-train-steps 1000 \ +# --learning-rate 2e-5 \ +# --dataset yaofu/slimpajama-per-source-length-upsample \ +# --model output/7B_32K_bs_1M_rope_1M_step_1000_lr_2e-5 \ +# --seq-length 65536 \ +# --rope-theta 5000000 \ +# --parallel_mode data_parallel + +# rm output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 4 \ +# --output-dir ./output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5 \ +# --seed 2023 \ +# --wandb EasyContext \ +# --max-train-steps 500 \ +# --learning-rate 2e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ +# --model output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5 \ +# --seq-length 256000 \ +# --rope-theta 10000000 \ +# --parallel_mode zigzag_ring_attn + +# rm output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 4 \ +# --output-dir ./output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5 \ +# --seed 2024 \ +# --wandb EasyContext \ +# --max-train-steps 500 \ +# --learning-rate 2e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ +# --model output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5 \ +# --seq-length 256000 \ +# --rope-theta 25000000 \ +# --parallel_mode zigzag_ring_attn + +# rm output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 4 \ +# --output-dir ./output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5 \ +# --seed 2025 \ +# --wandb EasyContext \ +# --max-train-steps 150 \ +# --learning-rate 2e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ +# --model output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5 \ +# --seq-length 256000 \ +# --rope-theta 50000000 \ +# --parallel_mode zigzag_ring_attn + +# rm output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 2 \ +# --output-dir ./output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5 \ +# --seed 2026 \ +# --wandb EasyContext \ +# --max-train-steps 300 \ +# --learning-rate 2e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_1M \ +# --model output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5 \ +# --seq-length 512000 \ +# --rope-theta 100000000 \ +# --parallel_mode zigzag_ring_attn + + +# rm output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 2 \ +# --output-dir ./output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5 \ +# --seed 2027 \ +# --wandb EasyContext \ +# --max-train-steps 90 \ +# --learning-rate 1e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_1M \ +# --model output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5 \ +# --seq-length 
512000 \ +# --rope-theta 250000000 \ +# --parallel_mode zigzag_ring_attn + +# rm output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5/model.safetensors + + +### Finally we directly set the rope_theta in output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5/config.json to 1,000,000,000 diff --git a/Llama3-70B-sft.sh b/Llama3-70B-sft.sh new file mode 100644 index 0000000000..045b492c0c --- /dev/null +++ b/Llama3-70B-sft.sh @@ -0,0 +1,176 @@ +# You can observe that the number of steps for different stage is quite different. They are not magic number. They are set to those numbers simply because I esitimate the time it takes to finish the training, and +# choose the number such that it fits my daily schedule>_<. This is for you to exactly reproduce my results. You many change the steps to other numbers if you want to. +MODEL_DIR=${MODEL_DIR:-"/mnt/zj-gpfs/home/lsy/models/Meta-Llama-3-8B"} +NGPUS=${NGPUS:-8} +WORLD_SIZE=${WORLD_SIZE:-1} +NUM_PROCESSES=$((${NGPUS}*${WORLD_SIZE})) +SEQ_LEN=${SEQ_LEN:-32768} +BATCH_SIZE=${BATCH_SIZE:-1} +ACC=${ACC:-6} +SAVE_PATH=${SAVE_PATH:-"./output/70B_32K_bs_1M_rope_1M_step_1000_lr_2e-5"} +export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' +export NCCL_P2P_DISABLE=1 +export NCCL_IB_GID_INDEX=3 +export WANDB_DISABLED=true +export NCCL_DEBUG=WARN +export CUDA_LAUNCH_BLOCKING=1 +echo ${RANK}/$[WORLD_SIZE] +if [ ${MASTER_ADDR} == 'localhost' ]; then + export MASTER_ADDR=`hostname -i` +fi +echo ${MASTER_ADDR}:${MASTER_PORT} + +accelerate launch \ +--config_file examples/accelerate/ds_multi_nodes.yaml \ +--use_deepspeed \ +--num_machines ${WORLD_SIZE} \ +--num_processes ${NUM_PROCESSES} \ +--main_process_ip ${MASTER_ADDR} \ +--main_process_port ${MASTER_PORT} \ +--machine_rank ${RANK} \ +--rdzv_conf "rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT,rdzv_backend=c10d" \ +src/train.py \ +--model_name_or_path ${MODEL_DIR} \ +--stage sft \ +--do_train \ +--finetuning_type full \ +--parallel_mode dist_flash_attn \ +--deepspeed examples/deepspeed/ds_z3_offload_config.json \ +--dataset long_sft_32k \ +--template llama3 \ +--cutoff_len ${SEQ_LEN} \ +--max_samples 20000 \ +--overwrite_cache \ +--preprocessing_num_workers 16 \ +--output_dir ${SAVE_PATH} \ +--logging_steps 1 \ +--save_steps 130 \ +--plot_loss \ +--overwrite_output_dir \ +--per_device_train_batch_size ${BATCH_SIZE} \ +--gradient_accumulation_steps ${ACC} \ +--learning_rate 2e-5 \ +--num_train_epochs 3.0 \ +--lr_scheduler_type cosine \ +--warmup_ratio 0.1 \ +--bf16 \ +--ddp_timeout 180000000 \ +--val_size 0.1 \ +--eval_strategy steps \ +--eval_steps 50000 \ +--dataloader_drop_last + +# In the saved files, there are model-00001-of-00003.safetensors to model-00001-of-00003.safetensors. Somehow model.safetensors is unnecessary and should be removed. 
+# rm output/7B_32K_bs_1M_rope_1M_step_1000_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 2 \ +# --output-dir ./output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5 \ +# --seed 2022 \ +# --wandb EasyContext \ +# --max-train-steps 1000 \ +# --learning-rate 2e-5 \ +# --dataset yaofu/slimpajama-per-source-length-upsample \ +# --model output/7B_32K_bs_1M_rope_1M_step_1000_lr_2e-5 \ +# --seq-length 65536 \ +# --rope-theta 5000000 \ +# --parallel_mode data_parallel + +# rm output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 4 \ +# --output-dir ./output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5 \ +# --seed 2023 \ +# --wandb EasyContext \ +# --max-train-steps 500 \ +# --learning-rate 2e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ +# --model output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5 \ +# --seq-length 256000 \ +# --rope-theta 10000000 \ +# --parallel_mode zigzag_ring_attn + +# rm output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 4 \ +# --output-dir ./output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5 \ +# --seed 2024 \ +# --wandb EasyContext \ +# --max-train-steps 500 \ +# --learning-rate 2e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ +# --model output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5 \ +# --seq-length 256000 \ +# --rope-theta 25000000 \ +# --parallel_mode zigzag_ring_attn + +# rm output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 4 \ +# --output-dir ./output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5 \ +# --seed 2025 \ +# --wandb EasyContext \ +# --max-train-steps 150 \ +# --learning-rate 2e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ +# --model output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5 \ +# --seq-length 256000 \ +# --rope-theta 50000000 \ +# --parallel_mode zigzag_ring_attn + +# rm output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 2 \ +# --output-dir ./output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5 \ +# --seed 2026 \ +# --wandb EasyContext \ +# --max-train-steps 300 \ +# --learning-rate 2e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_1M \ +# --model output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5 \ +# --seq-length 512000 \ +# --rope-theta 100000000 \ +# --parallel_mode zigzag_ring_attn + + +# rm output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 2 \ +# --output-dir ./output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5 \ +# --seed 2027 \ +# --wandb EasyContext \ +# --max-train-steps 90 \ +# --learning-rate 1e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_1M \ +# --model output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5 \ +# --seq-length 
512000 \ +# --rope-theta 250000000 \ +# --parallel_mode zigzag_ring_attn + +# rm output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5/model.safetensors + + +### Finally we directly set the rope_theta in output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5/config.json to 1,000,000,000 diff --git a/src/llamafactory/data/__init__.py b/src/llamafactory/data/__init__.py index b08691d38b..f2ca6b9e9d 100644 --- a/src/llamafactory/data/__init__.py +++ b/src/llamafactory/data/__init__.py @@ -1,4 +1,4 @@ -from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding +from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding, SeqParallelDataCollatorForLanguageModeling from .data_utils import Role, split_dataset from .loader import get_dataset from .template import TEMPLATES, Template, get_template_and_fix_tokenizer @@ -7,6 +7,7 @@ __all__ = [ "KTODataCollatorWithPadding", "PairwiseDataCollatorWithPadding", + "SeqParallelDataCollatorForLanguageModeling", "Role", "split_dataset", "get_dataset", diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 3abf6a5a1a..92cf80f58e 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -2,7 +2,11 @@ from typing import Any, Dict, Sequence import torch +<<<<<<< Updated upstream from transformers import DataCollatorForSeq2Seq +======= +from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling +>>>>>>> Stashed changes from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union from llamafactory.easy_context import prepare_seq_parallel_sft_inputs @@ -84,13 +88,12 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor @dataclass class SeqParallelDataCollator(DataCollatorForSeq2Seq): r""" - Data collator for sequence parallel. + Data collator for sequence parallel in supervised finetune(sft) stage. """ seq_algo: str = "data_parallel" rank: int = 0 world_size: int = 8 device: Optional[Any] = None - def __call__(self, features: Sequence[Dict[str, Any]], return_tensors=None) -> Dict[str, torch.Tensor]: batch = super().__call__(features, return_tensors) if self.seq_algo == "data_parallel": @@ -98,12 +101,41 @@ def __call__(self, features: Sequence[Dict[str, Any]], return_tensors=None) -> D input_ids = batch["input_ids"] attention_mask = batch["attention_mask"] labels = batch["labels"] - batch = prepare_seq_parallel_sft_inputs(self.seq_algo, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=None, - labels=labels, - rank=self.rank, - world_size=self.world_size, + batch = prepare_seq_parallel_sft_inputs(self.seq_algo, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=None, + labels=labels, + rank=self.rank, + world_size=self.world_size, + device=self.device) + return batch + + +@dataclass +class SeqParallelDataCollatorForLanguageModeling(DataCollatorForLanguageModeling): + r""" + Data collator for sequence parallel in pretrain(pt) stage. + Reuse the sequence parallel distributing function for sft stage. 
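+    In any mode other than data_parallel, each rank keeps only its own shard of every
+    sequence; the split itself is delegated to `prepare_seq_parallel_sft_inputs`.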
+ """ + seq_algo: str = "data_parallel" + rank: int = 0 + world_size: int = 8 + device: Optional[Any] = None + + def __call__(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: + batch = super().__call__(examples) + if self.seq_algo == "data_parallel": + return batch + input_ids = batch["input_ids"] + attention_mask = batch["attention_mask"] + labels = batch["labels"] + batch = prepare_seq_parallel_sft_inputs(self.seq_algo, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=None, + labels=labels, + rank=self.rank, + world_size=self.world_size, device=self.device) return batch diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index ba426f8156..8b356e380e 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -143,6 +143,18 @@ def get_dataset( if has_tokenized_data(data_args.tokenized_path): logger.warning("Loading dataset from disk will ignore other data arguments.") dataset = load_from_disk(data_args.tokenized_path) + # ---lsy--- + to_remove = [col for col in dataset.column_names if col != "input_ids"] + # import copy + # first_item = copy.deepcopy(dataset[0]['input_ids']) + def update_column(example): + example['input_ids'] = example['input_ids'][:data_args.cutoff_len] + # example['input_ids'] = first_item[:data_args.cutoff_len] + return example + + # # 使用 map 方法添加新列 + dataset = dataset.map(update_column,remove_columns=to_remove) + # ---lsy--- logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path)) if data_args.streaming: dataset = dataset.to_iterable_dataset() @@ -166,6 +178,7 @@ def get_dataset( data_args, training_args, stage, template, tokenizer, processor ) column_names = list(next(iter(dataset)).keys()) + logger.debug(f"remove_columns:{column_names}") kwargs = {} if not data_args.streaming: kwargs = dict( @@ -175,9 +188,9 @@ def get_dataset( ) dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs) - + if data_args.tokenized_path is not None: - if training_args.should_save: + if training_args.should_save: dataset.save_to_disk(data_args.tokenized_path) logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path)) logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path)) diff --git a/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py b/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py index 4f0ab7bfae..0f9daac2c5 100644 --- a/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py +++ b/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py @@ -15,7 +15,6 @@ from .lightseq_async_attn import _lightseq_forward, _lightseq_backward from .async_communication import initialize_distributed, reset_global_memory_buffer - # define a global buffer to save flash attention outputs # it's called global because it saves the outputs for all layers global_flash_attn_out_buffer = None @@ -497,7 +496,7 @@ def forward( next_decoder_cache = () if use_cache else None # apply flash-attention friendly gradient checkpointing - if self.gradient_checkpointing: + if self.gradient_checkpointing and self.training: for idx in range(len(self.layers) + 1): if output_hidden_states: all_hidden_states += (hidden_states,) diff --git a/src/llamafactory/easy_context/dist_flash_attn/prepare_input.py b/src/llamafactory/easy_context/dist_flash_attn/prepare_input.py index 036bbc9aa0..ed081a32e4 100644 --- 
a/src/llamafactory/easy_context/dist_flash_attn/prepare_input.py +++ b/src/llamafactory/easy_context/dist_flash_attn/prepare_input.py @@ -69,4 +69,4 @@ def prepare_dist_flash_attn_sft_inputs( "attention_mask": local_attention_mask, "position_ids": local_position_ids, "labels": local_labels, - } \ No newline at end of file + } diff --git a/src/llamafactory/extras/logging.py b/src/llamafactory/extras/logging.py index 430b8a48bb..a281ff12d2 100644 --- a/src/llamafactory/extras/logging.py +++ b/src/llamafactory/extras/logging.py @@ -16,7 +16,7 @@ def __init__(self, output_dir: str) -> None: formatter = logging.Formatter( fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S" ) - self.setLevel(logging.INFO) + self.setLevel(logging.DEBUG) self.setFormatter(formatter) os.makedirs(output_dir, exist_ok=True) @@ -53,7 +53,7 @@ def get_logger(name: str) -> logging.Logger: handler.setFormatter(formatter) logger = logging.getLogger(name) - logger.setLevel(logging.INFO) + logger.setLevel(logging.DEBUG) logger.addHandler(handler) return logger diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index ec5dd62c59..5e7b9abd1a 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -294,7 +294,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: str(model_args.compute_dtype), ) ) - + logger.info(f"seed is:{training_args.seed}") transformers.set_seed(training_args.seed) return model_args, data_args, training_args, finetuning_args, generating_args diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py index 697a04e77c..ad4f7dbc95 100644 --- a/src/llamafactory/model/loader.py +++ b/src/llamafactory/model/loader.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict -from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer +from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer, LlamaForCausalLM from trl import AutoModelForCausalLMWithValueHead from ..extras.logging import get_logger diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py index 1d96e82f63..18f73d8512 100644 --- a/src/llamafactory/train/pt/trainer.py +++ b/src/llamafactory/train/pt/trainer.py @@ -5,7 +5,12 @@ from ...extras.logging import get_logger from ..trainer_utils import create_custom_optimzer, create_custom_scheduler - +import torch +from torch.utils.data import DataLoader +from transformers.utils import is_datasets_available +from transformers.trainer_utils import seed_worker +import datasets +from torch.nn import CrossEntropyLoss if TYPE_CHECKING: import torch @@ -49,3 +54,148 @@ def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, if self.processor is not None: output_dir = output_dir if output_dir is not None else self.args.output_dir getattr(self.processor, "image_processor").save_pretrained(output_dir) + +class CustomSeqParallelTrainer(CustomTrainer): + def compute_loss(self, model, inputs, return_outputs=False): + """ + How the loss is computed by Trainer. By default, all models return the loss in the first element. + + Subclass and override for custom behavior. 
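+        In sequence-parallel modes the loss reported by the model is not used; a
+        cross-entropy value is recomputed here from the returned logits, one per
+        sequence, and then averaged.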
+ """ + if self.label_smoother is not None and "labels" in inputs: + labels = inputs.pop("labels") + else: + labels = None + outputs = model(**inputs) + # Save past state if it exists + # TODO: this needs to be fixed and made cleaner later. + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index] + + if labels is not None: + unwrapped_model = self.accelerator.unwrap_model(model) + if _is_peft_model(unwrapped_model): + model_name = unwrapped_model.base_model.model._get_name() + else: + model_name = unwrapped_model._get_name() + if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): + loss = self.label_smoother(outputs, labels, shift_labels=True) + else: + loss = self.label_smoother(outputs, labels) + else: + if isinstance(outputs, dict) and "loss" not in outputs: + raise ValueError( + "The model did not return a loss from the inputs, only the following keys: " + f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}." + ) + # We don't use .loss here since the model may return tuples instead of ModelOutput. + if self.finetuning_args.parallel_mode== "data_parallel": + loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] + + else: + loss_fn = CrossEntropyLoss() + labels = inputs.pop("labels") + logits = outputs["logits"] if isinstance(outputs, dict) else outputs[1] + + # valid_label_cnt = (labels!=-100).sum(1)[None, :] + # print(f"valid label cnt:{valid_label_cnt}") + # valid_label_cnt_gather = self.accelerator.gather(valid_label_cnt) + # # valid_label_cnt_gather:[ngpus, bs] + # n_gpus = valid_label_cnt_gather.shape[0] + # valid_label_cnt_all =valid_label_cnt_gather.sum(0) #[bs] + shift_logits = logits.contiguous() + shift_labels = labels.contiguous() + bs = len(shift_labels) + loss = torch.zeros(bs, dtype=shift_logits.dtype, device=shift_labels.device) + + for b in range(bs): + loss[b]=loss_fn(shift_logits[b], shift_labels[b]) + loss = loss.mean() + + return (loss, outputs) if return_outputs else loss + + def get_train_dataloader(self) -> DataLoader: + """ + Returns the training [`~torch.utils.data.DataLoader`]. + + Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed + training if necessary) otherwise. + + Subclass and override this method if you want to inject some custom behavior. 
+ """ + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + train_dataset = self.train_dataset + data_collator = self.data_collator + if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): + train_dataset = self._remove_unused_columns(train_dataset, description="training") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="training") + + dataloader_params = { + "batch_size": self._train_batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": False, + "persistent_workers": self.args.dataloader_persistent_workers, + } + + if not isinstance(train_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_train_sampler() + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["worker_init_fn"] = seed_worker + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + if hasattr(data_collator, "seq_algo") and data_collator.seq_algo != "data_parallel": + return DataLoader(train_dataset, **dataloader_params) + return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + + def get_eval_dataloader(self, eval_dataset) -> DataLoader: + """ + Returns the evaluation [`~torch.utils.data.DataLoader`]. + + Subclass and override this method if you want to inject some custom behavior. + + Args: + eval_dataset (`torch.utils.data.Dataset`, *optional*): + If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted + by the `model.forward()` method are automatically removed. It must implement `__len__`. + """ + if eval_dataset is None and self.eval_dataset is None: + raise ValueError("Trainer: evaluation requires an eval_dataset.") + + # If we have persistent workers, don't do a fork bomb especially as eval datasets + # don't change during training + if hasattr(self, "_eval_dataloader") and self.args.dataloader_persistent_workers: + return self.accelerator.prepare(self._eval_dataloader) + eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + data_collator = self.data_collator + + if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset): + eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation") + + dataloader_params = { + "batch_size": self.args.eval_batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": False, + "persistent_workers": self.args.dataloader_persistent_workers, + } + + if not isinstance(eval_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset) + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + # accelerator.free_memory() will destroy the references, so + # we need to store the non-prepared version + eval_dataloader = DataLoader(eval_dataset, **dataloader_params) + if self.args.dataloader_persistent_workers: + self._eval_dataloader = eval_dataloader + + if hasattr(data_collator, "seq_algo") and data_collator.seq_algo != "data_parallel": + return eval_dataloader + return self.accelerator.prepare(eval_dataloader) \ No newline at end of file diff --git a/src/llamafactory/train/pt/workflow.py 
b/src/llamafactory/train/pt/workflow.py index 8a6355674d..1d9a36f673 100644 --- a/src/llamafactory/train/pt/workflow.py +++ b/src/llamafactory/train/pt/workflow.py @@ -5,12 +5,15 @@ from transformers import DataCollatorForLanguageModeling -from ...data import get_dataset, split_dataset +from ...data import get_dataset, split_dataset, SeqParallelDataCollatorForLanguageModeling from ...extras.ploting import plot_loss from ...model import load_model, load_tokenizer from ..trainer_utils import create_modelcard_and_push -from .trainer import CustomTrainer +from .trainer import CustomTrainer, CustomSeqParallelTrainer +import os +import torch +from ...easy_context import apply_seq_parallel_monkey_patch if TYPE_CHECKING: from transformers import Seq2SeqTrainingArguments, TrainerCallback @@ -29,10 +32,33 @@ def run_pt( tokenizer = tokenizer_module["tokenizer"] dataset = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) - data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + apply_seq_parallel_monkey_patch(finetuning_args.parallel_mode, "llama") + + # data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + local_rank = int(os.getenv("LOCAL_RANK")) + print(f"seq_len: {data_args.cutoff_len}") + + data_collator = SeqParallelDataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm=False, + seq_algo=finetuning_args.parallel_mode, + rank=torch.distributed.get_rank(), + world_size=torch.distributed.get_world_size(), + device=torch.device("cuda", local_rank) + ) # Initialize our Trainer - trainer = CustomTrainer( + # trainer = CustomTrainer( + # model=model, + # args=training_args, + # finetuning_args=finetuning_args, + # data_collator=data_collator, + # callbacks=callbacks, + # **tokenizer_module, + # **split_dataset(dataset, data_args, training_args), + # ) + + trainer = CustomSeqParallelTrainer( model=model, args=training_args, finetuning_args=finetuning_args, @@ -51,18 +77,18 @@ def run_pt( trainer.save_state() if trainer.is_world_process_zero() and finetuning_args.plot_loss: plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) - + # Evaluation - if training_args.do_eval: - metrics = trainer.evaluate(metric_key_prefix="eval") - try: - perplexity = math.exp(metrics["eval_loss"]) - except OverflowError: - perplexity = float("inf") + # if training_args.do_eval: + # metrics = trainer.evaluate(metric_key_prefix="eval") + # try: + # perplexity = math.exp(metrics["eval_loss"]) + # except OverflowError: + # perplexity = float("inf") - metrics["perplexity"] = perplexity - trainer.log_metrics("eval", metrics) - trainer.save_metrics("eval", metrics) + # metrics["perplexity"] = perplexity + # trainer.log_metrics("eval", metrics) + # trainer.save_metrics("eval", metrics) # Create model card create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args) diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index 8a569afab4..021b324d24 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -136,7 +136,6 @@ def save_predictions(self, predict_results: "PredictionOutput") -> None: writer.write("\n".join(res)) class CustomSeqParallelTrainer(CustomSeq2SeqTrainer): - def compute_loss(self, model, inputs, return_outputs=False): """ How the loss is computed by Trainer. By default, all models return the loss in the first element. 
@@ -189,8 +188,6 @@ def compute_loss(self, model, inputs, return_outputs=False): for b in range(bs): normalizer=valid_label_cnt_all[b].item() loss[b]=loss_fn(shift_logits[b], shift_labels[b])/normalizer - #print(f"loss={loss}, rank={os.environ['RANK']}") - #time.sleep(60) loss = loss.mean()*n_gpus return (loss, outputs) if return_outputs else loss @@ -230,7 +227,7 @@ def get_train_dataloader(self) -> DataLoader: if hasattr(data_collator, "seq_algo") and data_collator.seq_algo != "data_parallel": return DataLoader(train_dataset, **dataloader_params) return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) - + def get_eval_dataloader(self, eval_dataset) -> DataLoader: """ Returns the evaluation [`~torch.utils.data.DataLoader`]. diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py index 098396a074..5fc3bb5aaa 100644 --- a/src/llamafactory/train/sft/workflow.py +++ b/src/llamafactory/train/sft/workflow.py @@ -86,26 +86,26 @@ def run_sft( trainer.save_model() trainer.log_metrics("train", train_result.metrics) trainer.save_metrics("train", train_result.metrics) - trainer.save_state() + # trainer.save_state() if trainer.is_world_process_zero() and finetuning_args.plot_loss: plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) - # Evaluation - if training_args.do_eval: - metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs) - if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled - metrics.pop("eval_loss", None) - trainer.log_metrics("eval", metrics) - trainer.save_metrics("eval", metrics) - - # Predict - if training_args.do_predict: - predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs) - if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled - predict_results.metrics.pop("predict_loss", None) - trainer.log_metrics("predict", predict_results.metrics) - trainer.save_metrics("predict", predict_results.metrics) - trainer.save_predictions(predict_results) + # # Evaluation + # if training_args.do_eval: + # metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs) + # if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled + # metrics.pop("eval_loss", None) + # trainer.log_metrics("eval", metrics) + # trainer.save_metrics("eval", metrics) + + # # Predict + # if training_args.do_predict: + # predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs) + # if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled + # predict_results.metrics.pop("predict_loss", None) + # trainer.log_metrics("predict", predict_results.metrics) + # trainer.save_metrics("predict", predict_results.metrics) + # trainer.save_predictions(predict_results) # Create model card create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args) diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py index 48944a630a..d3cf37dcc1 100644 --- a/src/llamafactory/train/trainer_utils.py +++ b/src/llamafactory/train/trainer_utils.py @@ -400,6 +400,8 @@ def get_batch_logps( labels = labels[:, 1:].clone() logits = logits[:, :-1, :] + import os + print(f"---debug---rank:{os.environ['RANK']}, logits is:{logits},local rank:{os.getenv('LOCAL_RANK')}") loss_mask = labels != label_pad_token_id labels[labels == label_pad_token_id] = 0 # dummy token 
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py index eed875e92a..8b863e1214 100644 --- a/src/llamafactory/train/tuner.py +++ b/src/llamafactory/train/tuner.py @@ -15,6 +15,7 @@ from .rm import run_rm from .sft import run_sft +import os if TYPE_CHECKING: from transformers import TrainerCallback From 49bd709da2cedcff95c29368227cc59cbbbac0b4 Mon Sep 17 00:00:00 2001 From: qianhao0713 <475483052@qq.com> Date: Tue, 9 Jul 2024 20:17:24 +0800 Subject: [PATCH 03/31] remove useless line --- src/llamafactory/data/collator.py | 4 ---- src/llamafactory/train/sft/trainer.py | 2 +- src/llamafactory/train/trainer_utils.py | 2 -- src/llamafactory/train/tuner.py | 1 - 4 files changed, 1 insertion(+), 8 deletions(-) diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 92cf80f58e..cf6a87c568 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -2,11 +2,7 @@ from typing import Any, Dict, Sequence import torch -<<<<<<< Updated upstream -from transformers import DataCollatorForSeq2Seq -======= from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling ->>>>>>> Stashed changes from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union from llamafactory.easy_context import prepare_seq_parallel_sft_inputs diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index 021b324d24..b7130564d2 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -227,7 +227,7 @@ def get_train_dataloader(self) -> DataLoader: if hasattr(data_collator, "seq_algo") and data_collator.seq_algo != "data_parallel": return DataLoader(train_dataset, **dataloader_params) return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) - + def get_eval_dataloader(self, eval_dataset) -> DataLoader: """ Returns the evaluation [`~torch.utils.data.DataLoader`]. 
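For reference, the compute_loss overrides introduced above replace the model's own token-averaged loss whenever a sequence-parallel mode is active: the pretraining trainer averages one cross-entropy value per sequence, and the SFT trainer additionally normalizes each sequence by the valid-token count gathered across ranks. A minimal standalone sketch of the per-sequence reduction (a hypothetical helper, not code from this patch; it assumes logits of shape [batch, seq, vocab] and labels that mark ignored positions with -100):

import torch
from torch.nn import CrossEntropyLoss

def per_sequence_ce(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # One loss value per sequence, so short and long sequences weigh equally in the batch.
    loss_fn = CrossEntropyLoss()  # ignore_index defaults to -100
    per_seq = torch.stack([loss_fn(logits[b], labels[b]) for b in range(labels.size(0))])
    return per_seq.mean()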
diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py index d3cf37dcc1..48944a630a 100644 --- a/src/llamafactory/train/trainer_utils.py +++ b/src/llamafactory/train/trainer_utils.py @@ -400,8 +400,6 @@ def get_batch_logps( labels = labels[:, 1:].clone() logits = logits[:, :-1, :] - import os - print(f"---debug---rank:{os.environ['RANK']}, logits is:{logits},local rank:{os.getenv('LOCAL_RANK')}") loss_mask = labels != label_pad_token_id labels[labels == label_pad_token_id] = 0 # dummy token per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py index 8b863e1214..eed875e92a 100644 --- a/src/llamafactory/train/tuner.py +++ b/src/llamafactory/train/tuner.py @@ -15,7 +15,6 @@ from .rm import run_rm from .sft import run_sft -import os if TYPE_CHECKING: from transformers import TrainerCallback From 751c6a15e2665b5c84582a5c4061e0ca08272599 Mon Sep 17 00:00:00 2001 From: qianhao0713 <475483052@qq.com> Date: Wed, 10 Jul 2024 17:10:58 +0800 Subject: [PATCH 04/31] hybrid dp & sp for sft --- src/llamafactory/data/collator.py | 51 +++++++++++++------ src/llamafactory/easy_context/__init__.py | 4 +- .../dist_flash_attn/async_communication.py | 7 +-- .../dist_flash_attn/monkey_patch.py | 4 +- src/llamafactory/hparams/finetuning_args.py | 6 +++ src/llamafactory/train/sft/trainer.py | 14 ++++- src/llamafactory/train/sft/workflow.py | 3 +- 7 files changed, 63 insertions(+), 26 deletions(-) diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index cf6a87c568..db7c7f5c11 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -86,10 +86,12 @@ class SeqParallelDataCollator(DataCollatorForSeq2Seq): r""" Data collator for sequence parallel in supervised finetune(sft) stage. """ - seq_algo: str = "data_parallel" + seq_algo: str = "data_parallel", + seq_parallel_size: int = -1 rank: int = 0 world_size: int = 8 device: Optional[Any] = None + def __call__(self, features: Sequence[Dict[str, Any]], return_tensors=None) -> Dict[str, torch.Tensor]: batch = super().__call__(features, return_tensors) if self.seq_algo == "data_parallel": @@ -97,13 +99,21 @@ def __call__(self, features: Sequence[Dict[str, Any]], return_tensors=None) -> D input_ids = batch["input_ids"] attention_mask = batch["attention_mask"] labels = batch["labels"] - batch = prepare_seq_parallel_sft_inputs(self.seq_algo, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=None, - labels=labels, - rank=self.rank, - world_size=self.world_size, + if self.seq_parallel_size != -1: + dp_rank = self.rank // self.seq_parallel_size + bs = len(input_ids) + data_group_size = self.world_size // self.seq_parallel_size + group_bs = bs // data_group_size + input_ids = input_ids[dp_rank * group_bs: (dp_rank + 1) * group_bs] + attention_mask = attention_mask[dp_rank * group_bs: (dp_rank + 1) * group_bs] + labels = labels[dp_rank * group_bs: (dp_rank + 1) * group_bs] + batch = prepare_seq_parallel_sft_inputs(self.seq_algo, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=None, + labels=labels, + rank=self.rank, + world_size=self.world_size, device=self.device) return batch @@ -115,23 +125,32 @@ class SeqParallelDataCollatorForLanguageModeling(DataCollatorForLanguageModeling Reuse the sequence parallel distributing function for sft stage. 
""" seq_algo: str = "data_parallel" + seq_parallel_size: int = -1 rank: int = 0 world_size: int = 8 device: Optional[Any] = None - + def __call__(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: batch = super().__call__(examples) if self.seq_algo == "data_parallel": return batch + if self.seq_parallel_size != -1: + dp_rank = self.rank // self.seq_parallel_size + bs = len(input_ids) + data_group_size = self.world_size // self.seq_parallel_size + group_bs = bs // data_group_size + input_ids = input_ids[dp_rank * group_bs: (dp_rank + 1) * group_bs] + attention_mask = attention_mask[dp_rank * group_bs: (dp_rank + 1) * group_bs] + labels = labels[dp_rank * group_bs: (dp_rank + 1) * group_bs] input_ids = batch["input_ids"] attention_mask = batch["attention_mask"] labels = batch["labels"] - batch = prepare_seq_parallel_sft_inputs(self.seq_algo, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=None, - labels=labels, - rank=self.rank, - world_size=self.world_size, + batch = prepare_seq_parallel_sft_inputs(self.seq_algo, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=None, + labels=labels, + rank=self.rank, + world_size=self.world_size, device=self.device) return batch diff --git a/src/llamafactory/easy_context/__init__.py b/src/llamafactory/easy_context/__init__.py index 687c018de2..bef8ecab74 100644 --- a/src/llamafactory/easy_context/__init__.py +++ b/src/llamafactory/easy_context/__init__.py @@ -64,7 +64,7 @@ def prepare_seq_parallel_sft_inputs( raise ValueError(f"Invalid seq_algo: {seq_algo}") def apply_seq_parallel_monkey_patch( - seq_algo, model + seq_algo, model,seq_parallel_size=None ): assert seq_algo in ["zigzag_ring_attn", "dist_flash_attn", "ulysses_attn", "data_parallel"], f"Invalid seq_algo: {seq_algo}" assert model in ["llama", "mistral"], f"Invalid model: {model}" @@ -75,7 +75,7 @@ def apply_seq_parallel_monkey_patch( elif seq_algo == "zigzag_ring_attn" and model == "mistral": apply_zigzag_ring_attn_monkey_patch_mistral() elif seq_algo == "dist_flash_attn" and model == "llama": - apply_dist_flash_attn_monkey_patch_llama() + apply_dist_flash_attn_monkey_patch_llama(seq_parallel_size=seq_parallel_size) elif seq_algo == "ulysses_attn" and model == "llama": apply_ulysses_attn_monkey_patch_llama() else: diff --git a/src/llamafactory/easy_context/dist_flash_attn/async_communication.py b/src/llamafactory/easy_context/dist_flash_attn/async_communication.py index 610080ea3b..8055737adb 100644 --- a/src/llamafactory/easy_context/dist_flash_attn/async_communication.py +++ b/src/llamafactory/easy_context/dist_flash_attn/async_communication.py @@ -39,7 +39,7 @@ _bwd_send_volume = 0 _bwd_recv_volume = 0 -def initialize_distributed(): +def initialize_distributed(sequence_parallel_size=None): if dist.is_initialized(): if dist.get_rank() == 0: print( @@ -55,12 +55,13 @@ def initialize_distributed(): global_world_size = dist.get_world_size() torch.cuda.set_device(dist.get_rank() % local_world_size) - _initialize_sequence_parallel() + _initialize_sequence_parallel(sequence_parallel_size=sequence_parallel_size) # create_nccl_communicators() def _initialize_sequence_parallel(sequence_parallel_size=None): # Get world size and rank. Ensure some consistencies. - assert sequence_parallel_size is None, "Multiple sequence parallel group not implemented." + # assert sequence_parallel_size is None, "Multiple sequence parallel group not implemented." 
+ print(f"sequence_parallel_size is {sequence_parallel_size}") assert torch.distributed.is_initialized() world_size: int = torch.distributed.get_world_size() diff --git a/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py b/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py index 0f9daac2c5..63ba6f4973 100644 --- a/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py +++ b/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py @@ -602,7 +602,7 @@ def custom_forward(*inputs): ) -def apply_dist_flash_attn_monkey_patch_llama(): - initialize_distributed() +def apply_dist_flash_attn_monkey_patch_llama(seq_parallel_size=None): + initialize_distributed(sequence_parallel_size=seq_parallel_size) transformers.models.llama.modeling_llama.LlamaModel.forward = forward transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = llama_layer_forward diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index 5d0f52993d..b636afdf37 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -316,6 +316,12 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA default="data_parallel", metadata={"help": "which sequence parallel mode to use."}, ) + seq_parallel_size: int = field( + default=-1, + metadata={ + "help": "used for use seq_parallel and data_parallel simultaneously, -1 for all gpus parallels in sequence_length axis, n for n_gpus makes a sequence_parallel group" + } + ) def __post_init__(self): def split_arg(arg): diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index b7130564d2..e1349eacdd 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -171,8 +171,6 @@ def compute_loss(self, model, inputs, return_outputs=False): # We don't use .loss here since the model may return tuples instead of ModelOutput. 
if self.finetuning_args.parallel_mode== "data_parallel": loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] - #print(f"loss={loss}, rank={os.environ['RANK']}") - #time.sleep(60) else: loss_fn = CrossEntropyLoss(reduction='sum') labels = inputs.pop("labels") @@ -225,6 +223,12 @@ def get_train_dataloader(self) -> DataLoader: dataloader_params["worker_init_fn"] = seed_worker dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor if hasattr(data_collator, "seq_algo") and data_collator.seq_algo != "data_parallel": + seq_parallel_size = self.finetuning_args.seq_parallel_size + if seq_parallel_size != -1: + world_size = int(os.environ['WORLD_SIZE']) + assert seq_parallel_size != 0 and world_size % seq_parallel_size == 0, f"world_size: {world_size} should be devide by seq_parallel_size: {seq_parallel_size}" + data_parallel_size = world_size // seq_parallel_size + dataloader_params["batch_size"] = dataloader_params["batch_size"] * data_parallel_size return DataLoader(train_dataset, **dataloader_params) return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) @@ -274,5 +278,11 @@ def get_eval_dataloader(self, eval_dataset) -> DataLoader: self._eval_dataloader = eval_dataloader if hasattr(data_collator, "seq_algo") and data_collator.seq_algo != "data_parallel": + seq_parallel_size = self.args.seq_parallel_size + if seq_parallel_size != -1: + world_size = int(os.environ['WORLD_SIZE']) + assert seq_parallel_size != 0 and world_size % seq_parallel_size == 0, f"world_size: {world_size} should be devide by seq_parallel_size: {seq_parallel_size}" + data_parallel_size = world_size // seq_parallel_size + dataloader_params["batch_size"] = dataloader_params["batch_size"] * data_parallel_size return eval_dataloader return self.accelerator.prepare(eval_dataloader) diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py index 5fc3bb5aaa..c3a082ad38 100644 --- a/src/llamafactory/train/sft/workflow.py +++ b/src/llamafactory/train/sft/workflow.py @@ -37,7 +37,7 @@ def run_sft( tokenizer = tokenizer_module["tokenizer"] dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) - apply_seq_parallel_monkey_patch(finetuning_args.parallel_mode, "llama") + apply_seq_parallel_monkey_patch(finetuning_args.parallel_mode, "llama", seq_parallel_size=finetuning_args.seq_parallel_size) if training_args.predict_with_generate: tokenizer.padding_side = "left" # use left-padding in generation @@ -53,6 +53,7 @@ def run_sft( pad_to_multiple_of=data_args.cutoff_len if tokenizer.padding_side == "right" else None, label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id, seq_algo=finetuning_args.parallel_mode, + seq_parallel_size=finetuning_args.seq_parallel_size, rank=torch.distributed.get_rank(), world_size=world_size, device=torch.device("cuda", local_rank) From ba4e7686dcd9a2b3c79a85cc2d0e46587913c6c6 Mon Sep 17 00:00:00 2001 From: qianhao <475483052@qq.com> Date: Wed, 10 Jul 2024 09:39:48 +0000 Subject: [PATCH 05/31] add sp_group_size params for launch shell --- Llama3-70B.sh | 2 ++ Llama3-8B.sh | 2 ++ 2 files changed, 4 insertions(+) diff --git a/Llama3-70B.sh b/Llama3-70B.sh index 1101716720..b2a54ad7fb 100644 --- a/Llama3-70B.sh +++ b/Llama3-70B.sh @@ -5,6 +5,7 @@ NGPUS=${NGPUS:-8} WORLD_SIZE=${WORLD_SIZE:-1} NUM_PROCESSES=$[${NGPUS}*$[WORLD_SIZE]] SEQ_LEN=${SEQ_LEN:-32768} 
+SP_GROUP_SIZE=${SP_GROUP_SIZE:-1} BATCH_SIZE=${BATCH_SIZE:-1} export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' export WANDB_DISABLED=true @@ -29,6 +30,7 @@ src/train.py \ --do_train \ --finetuning_type full \ --parallel_mode dist_flash_attn \ +--seq_parallel_size ${SP_GROUP_SIZE} \ --deepspeed examples/deepspeed/ds_z3_offload_config.json \ --dataset long_sft_128k \ --template llama3 \ diff --git a/Llama3-8B.sh b/Llama3-8B.sh index b746698d47..46c081b412 100644 --- a/Llama3-8B.sh +++ b/Llama3-8B.sh @@ -5,6 +5,7 @@ NGPUS=${NGPUS:-8} WORLD_SIZE=${WORLD_SIZE:-1} NUM_PROCESSES=$[${NGPUS}*$[WORLD_SIZE]] SEQ_LEN=${SEQ_LEN:-1024} +SP_GROUP_SIZE=${SP_GROUP_SIZE:-1} BATCH_SIZE=${BATCH_SIZE:-1} export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' export WANDB_DISABLED=true @@ -30,6 +31,7 @@ src/train.py \ --finetuning_type full \ --lora_target all \ --parallel_mode dist_flash_attn \ +--seq_parallel_size ${SP_GROUP_SIZE} \ --deepspeed examples/deepspeed/ds_z3_offload_config.json \ --dataset alpaca_en \ --template llama3 \ From b4c8d2325b0bdf24699d64b9be75a9b3795931fe Mon Sep 17 00:00:00 2001 From: qianhao0713 <475483052@qq.com> Date: Thu, 11 Jul 2024 17:23:00 +0800 Subject: [PATCH 06/31] fix sp bug in hybrid sp&dp mode --- .../dist_flash_attn/async_communication.py | 67 ++++++++++--------- 1 file changed, 35 insertions(+), 32 deletions(-) diff --git a/src/llamafactory/easy_context/dist_flash_attn/async_communication.py b/src/llamafactory/easy_context/dist_flash_attn/async_communication.py index 8055737adb..a70f5f0fa7 100644 --- a/src/llamafactory/easy_context/dist_flash_attn/async_communication.py +++ b/src/llamafactory/easy_context/dist_flash_attn/async_communication.py @@ -4,7 +4,7 @@ import torch import torch.distributed as dist -from torch.distributed import batch_isend_irecv, P2POp, isend, irecv +from torch.distributed import batch_isend_irecv, P2POp, isend, irecv, get_process_group_ranks # Sequence parallel group that the current rank belongs to. _SEQUENCE_PARALLEL_GROUP = None @@ -61,9 +61,9 @@ def initialize_distributed(sequence_parallel_size=None): def _initialize_sequence_parallel(sequence_parallel_size=None): # Get world size and rank. Ensure some consistencies. # assert sequence_parallel_size is None, "Multiple sequence parallel group not implemented." - print(f"sequence_parallel_size is {sequence_parallel_size}") assert torch.distributed.is_initialized() world_size: int = torch.distributed.get_world_size() + print(f"sequence_parallel_size is {sequence_parallel_size}, world_size is {world_size}") if sequence_parallel_size is None: sequence_parallel_size = world_size @@ -266,6 +266,7 @@ def maybe_send_recv_fwd_qkvo(q: torch.Tensor, peer_q: torch.Tensor, seq_group = get_sequence_parallel_group() seq_rank = get_sequence_parallel_rank() seq_world_size = get_sequence_parallel_size() + seq_offset = get_process_group_ranks(seq_group)[0] # Handles for operations that actually need to be wait before going to the next iteration. 
# For instance, QKV sender never needs to wait -> it seems fusing these calls help scheduler; @@ -294,7 +295,7 @@ def maybe_send_recv_fwd_qkvo(q: torch.Tensor, peer_q: torch.Tensor, if time_step < (seq_world_size // 2 - 1): #print(f"t={time_step}: R={seq_rank} sends q to {maybe_send_rank % seq_world_size} (not wait)") #q_send_handles.append(P2POp(op=isend, tensor=q, peer=maybe_send_rank % seq_world_size, group=seq_group)) - all_handles.append(P2POp(op=isend, tensor=q, peer=maybe_send_rank % seq_world_size, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=q, peer=maybe_send_rank % seq_world_size + seq_offset, group=seq_group)) if debug: _fwd_send_volume += torch.numel(q) * q.element_size() else: @@ -302,8 +303,8 @@ def maybe_send_recv_fwd_qkvo(q: torch.Tensor, peer_q: torch.Tensor, #print(f"t={time_step}: R={seq_rank} sends kv to {maybe_send_rank} (not wait)") #kv_send_handles.append(P2POp(op=isend, tensor=k, peer=maybe_send_rank, group=seq_group)) #kv_send_handles.append(P2POp(op=isend, tensor=v, peer=maybe_send_rank, group=seq_group)) - all_handles.append(P2POp(op=isend, tensor=k, peer=maybe_send_rank, group=seq_group)) - all_handles.append(P2POp(op=isend, tensor=v, peer=maybe_send_rank, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=k, peer=maybe_send_rank + seq_offset, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=v, peer=maybe_send_rank + seq_offset, group=seq_group)) if debug: _fwd_send_volume += torch.numel(k) * k.element_size() _fwd_send_volume += torch.numel(v) * v.element_size() @@ -313,7 +314,7 @@ def maybe_send_recv_fwd_qkvo(q: torch.Tensor, peer_q: torch.Tensor, if time_step < (seq_world_size // 2 - 1): # print(f"t={time_step}: R={seq_rank} receives q from {maybe_recv_rank % seq_world_size} (wait)") #q_recv_handles.append(P2POp(op=irecv, tensor=peer_q, peer=maybe_recv_rank % seq_world_size, group=seq_group)) - all_handles.append(P2POp(op=irecv, tensor=peer_q, peer=maybe_recv_rank % seq_world_size, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=peer_q, peer=maybe_recv_rank % seq_world_size + seq_offset, group=seq_group)) if debug: _fwd_recv_volume += torch.numel(peer_q) * peer_q.element_size() else: @@ -321,8 +322,8 @@ def maybe_send_recv_fwd_qkvo(q: torch.Tensor, peer_q: torch.Tensor, #print(f"t={time_step}: R={seq_rank} receivs kv from {maybe_recv_rank} (wait)") #kv_recv_handles.append(P2POp(op=irecv, tensor=peer_k, peer=maybe_recv_rank, group=seq_group)) #kv_recv_handles.append(P2POp(op=irecv, tensor=peer_v, peer=maybe_recv_rank, group=seq_group)) - all_handles.append(P2POp(op=irecv, tensor=peer_k, peer=maybe_recv_rank, group=seq_group)) - all_handles.append(P2POp(op=irecv, tensor=peer_v, peer=maybe_recv_rank, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=peer_k, peer=maybe_recv_rank + seq_offset, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=peer_v, peer=maybe_recv_rank + seq_offset, group=seq_group)) if debug: _fwd_recv_volume += torch.numel(peer_k) * peer_k.element_size() _fwd_recv_volume += torch.numel(peer_v) * peer_v.element_size() @@ -333,14 +334,14 @@ def maybe_send_recv_fwd_qkvo(q: torch.Tensor, peer_q: torch.Tensor, for t in o_stats: # print(f"t={time_step}: R={seq_rank} sends o to {maybe_send_rank_o % seq_world_size} (wait)") #o_send_handles.append(P2POp(op=isend, tensor=t, peer=maybe_send_rank_o % seq_world_size, group=seq_group)) - all_handles.append(P2POp(op=isend, tensor=t, peer=maybe_send_rank_o % seq_world_size, group=seq_group)) + all_handles.append(P2POp(op=isend, 
tensor=t, peer=maybe_send_rank_o % seq_world_size + seq_offset, group=seq_group)) if debug: _fwd_send_volume += torch.numel(t) * t.element_size() if maybe_recv_rank_o >= seq_world_size and time_step > 1 : for t in o_stats: # print(f"t={time_step}: R={seq_rank} receives o from {maybe_recv_rank_o % seq_world_size} (wait)") #o_recv_handles.append(P2POp(op=irecv, tensor=t, peer=maybe_recv_rank_o % seq_world_size, group=seq_group)) - all_handles.append(P2POp(op=irecv, tensor=t, peer=maybe_recv_rank_o % seq_world_size, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=t, peer=maybe_recv_rank_o % seq_world_size + seq_offset, group=seq_group)) if debug: _fwd_recv_volume += torch.numel(t) * t.element_size() @@ -368,6 +369,7 @@ def maybe_send_recv_bwd_qkvo(dq_delta: torch.Tensor, dk_delta: torch.Tensor, seq_group = get_sequence_parallel_group() seq_rank = get_sequence_parallel_rank() seq_world_size = get_sequence_parallel_size() + seq_offset = get_process_group_ranks(seq_group)[0] all_handles = [] maybe_send_rank = seq_rank + (time_step + 1) @@ -379,10 +381,10 @@ def maybe_send_recv_bwd_qkvo(dq_delta: torch.Tensor, dk_delta: torch.Tensor, if maybe_send_rank >= seq_world_size: #send q, no one needs to do remote computation in the last time step if time_step < (seq_world_size // 2 - 1): - all_handles.append(P2POp(op=isend, tensor=q, peer=maybe_send_rank % seq_world_size, group=seq_group)) - all_handles.append(P2POp(op=isend, tensor=L, peer=maybe_send_rank % seq_world_size, group=seq_group)) - all_handles.append(P2POp(op=isend, tensor=o, peer=maybe_send_rank % seq_world_size, group=seq_group)) - all_handles.append(P2POp(op=isend, tensor=do, peer=maybe_send_rank % seq_world_size, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=q, peer=maybe_send_rank % seq_world_size + seq_offset, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=L, peer=maybe_send_rank % seq_world_size + seq_offset, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=o, peer=maybe_send_rank % seq_world_size + seq_offset, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=do, peer=maybe_send_rank % seq_world_size + seq_offset, group=seq_group)) if debug: _bwd_send_volume += torch.numel(q) * q.element_size() _bwd_send_volume += torch.numel(L) * L.element_size() @@ -390,8 +392,8 @@ def maybe_send_recv_bwd_qkvo(dq_delta: torch.Tensor, dk_delta: torch.Tensor, _bwd_send_volume += torch.numel(do) * do.element_size() else: # send kv - all_handles.append(P2POp(op=isend, tensor=k, peer=maybe_send_rank, group=seq_group)) - all_handles.append(P2POp(op=isend, tensor=v, peer=maybe_send_rank, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=k, peer=maybe_send_rank + seq_offset, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=v, peer=maybe_send_rank + seq_offset, group=seq_group)) if debug: _bwd_send_volume += torch.numel(k) * k.element_size() _bwd_send_volume += torch.numel(v) * v.element_size() @@ -399,10 +401,10 @@ def maybe_send_recv_bwd_qkvo(dq_delta: torch.Tensor, dk_delta: torch.Tensor, if maybe_recv_rank < 0: # recv q, no one needs to do remote computation in the last time step if time_step < (seq_world_size // 2 - 1): - all_handles.append(P2POp(op=irecv, tensor=peer_q, peer=maybe_recv_rank % seq_world_size, group=seq_group)) - all_handles.append(P2POp(op=irecv, tensor=peer_L, peer=maybe_recv_rank % seq_world_size, group=seq_group)) - all_handles.append(P2POp(op=irecv, tensor=peer_o, peer=maybe_recv_rank % seq_world_size, group=seq_group)) - 
all_handles.append(P2POp(op=irecv, tensor=peer_do, peer=maybe_recv_rank % seq_world_size, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=peer_q, peer=maybe_recv_rank % seq_world_size + seq_offset, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=peer_L, peer=maybe_recv_rank % seq_world_size + seq_offset, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=peer_o, peer=maybe_recv_rank % seq_world_size + seq_offset, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=peer_do, peer=maybe_recv_rank % seq_world_size + seq_offset, group=seq_group)) if debug: _bwd_recv_volume += torch.numel(peer_q) * peer_q.element_size() _bwd_recv_volume += torch.numel(peer_L) * peer_L.element_size() @@ -410,8 +412,8 @@ def maybe_send_recv_bwd_qkvo(dq_delta: torch.Tensor, dk_delta: torch.Tensor, _bwd_recv_volume += torch.numel(peer_do) * peer_do.element_size() else: # recv kv - all_handles.append(P2POp(op=irecv, tensor=peer_k, peer=maybe_recv_rank, group=seq_group)) - all_handles.append(P2POp(op=irecv, tensor=peer_v, peer=maybe_recv_rank, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=peer_k, peer=maybe_recv_rank + seq_offset, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=peer_v, peer=maybe_recv_rank + seq_offset, group=seq_group)) if debug: _bwd_recv_volume += torch.numel(peer_k) * peer_k.element_size() _bwd_recv_volume += torch.numel(peer_v) * peer_v.element_size() @@ -426,27 +428,27 @@ def maybe_send_recv_bwd_qkvo(dq_delta: torch.Tensor, dk_delta: torch.Tensor, if time_step > 1: if maybe_send_rank_dqkv < 0: #print(f"BWD t={time_step}: R={seq_rank} sends dq delta to {maybe_send_rank_dqkv % seq_world_size}") - all_handles.append(P2POp(op=isend, tensor=dq_delta, peer=maybe_send_rank_dqkv % seq_world_size, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=dq_delta, peer=maybe_send_rank_dqkv % seq_world_size + seq_offset, group=seq_group)) if debug: _bwd_send_volume += torch.numel(dq_delta) * dq_delta.element_size() else: #print(f"BWD t={time_step}: R={seq_rank} sends dkv delta to {maybe_send_rank_dqkv}") - all_handles.append(P2POp(op=isend, tensor=dk_delta, peer=maybe_send_rank_dqkv, group=seq_group)) - all_handles.append(P2POp(op=isend, tensor=dv_delta, peer=maybe_send_rank_dqkv, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=dk_delta, peer=maybe_send_rank_dqkv + seq_offset, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=dv_delta, peer=maybe_send_rank_dqkv + seq_offset, group=seq_group)) if debug: _bwd_send_volume += torch.numel(dk_delta) * dk_delta.element_size() _bwd_send_volume += torch.numel(dv_delta) * dv_delta.element_size() if maybe_recv_rank_dqkv >= seq_world_size: #print(f"BWD t={time_step}: R={seq_rank} receives dq delta to {maybe_recv_rank_dqkv % seq_world_size}") - all_handles.append(P2POp(op=irecv, tensor=dq_delta, peer=maybe_recv_rank_dqkv % seq_world_size, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=dq_delta, peer=maybe_recv_rank_dqkv % seq_world_size + seq_offset, group=seq_group)) is_update_dq = True if debug: _bwd_recv_volume += torch.numel(dq_delta) * dq_delta.element_size() else: #print(f"BWD t={time_step}: R={seq_rank} receives dk dv delta from {maybe_recv_rank_dqkv}") - all_handles.append(P2POp(op=irecv, tensor=dk_delta_from_peer, peer=maybe_recv_rank_dqkv, group=seq_group)) - all_handles.append(P2POp(op=irecv, tensor=dv_delta_from_peer, peer=maybe_recv_rank_dqkv, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=dk_delta_from_peer, 
peer=maybe_recv_rank_dqkv + seq_offset, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=dv_delta_from_peer, peer=maybe_recv_rank_dqkv + seq_offset, group=seq_group)) is_update_dkv = True if debug: _bwd_recv_volume += torch.numel(dk_delta_from_peer) * dk_delta_from_peer.element_size() @@ -462,6 +464,7 @@ def maybe_send_recv_bwd_last_dkv(dk_delta: torch.Tensor, dv_delta: torch.Tensor, seq_group = get_sequence_parallel_group() seq_rank = get_sequence_parallel_rank() seq_world_size = get_sequence_parallel_size() + seq_offset = get_process_group_ranks(seq_group)[0] if seq_world_size == 1: return [], is_update_last_dkv @@ -478,15 +481,15 @@ def maybe_send_recv_bwd_last_dkv(dk_delta: torch.Tensor, dv_delta: torch.Tensor, if maybe_send_rank >= 0: # print(f"BWD t={time_step}: R={seq_rank} last send dkv to {maybe_send_rank}") - all_handles.append(P2POp(op=isend, tensor=dk_delta, peer=maybe_send_rank, group=seq_group)) - all_handles.append(P2POp(op=isend, tensor=dv_delta, peer=maybe_send_rank, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=dk_delta, peer=maybe_send_rank + seq_offset, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=dv_delta, peer=maybe_send_rank + seq_offset, group=seq_group)) if debug: _bwd_send_volume += torch.numel(dk_delta) * dk_delta.element_size() _bwd_send_volume += torch.numel(dv_delta) * dv_delta.element_size() if maybe_recv_rank < seq_world_size: # print(f"BWD t={time_step}: R={seq_rank} last receive dkv from {maybe_recv_rank}") - all_handles.append(P2POp(op=irecv, tensor=dk_delta, peer=maybe_recv_rank, group=seq_group)) - all_handles.append(P2POp(op=irecv, tensor=dv_delta, peer=maybe_recv_rank, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=dk_delta, peer=maybe_recv_rank + seq_offset, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=dv_delta, peer=maybe_recv_rank + seq_offset, group=seq_group)) if debug: _bwd_recv_volume += torch.numel(dk_delta) * dk_delta.element_size() _bwd_recv_volume += torch.numel(dv_delta) * dv_delta.element_size() From fccd0c3b2664ab24f62edc90f810fcbb0a40c598 Mon Sep 17 00:00:00 2001 From: qianhao0713 <475483052@qq.com> Date: Thu, 11 Jul 2024 17:23:00 +0800 Subject: [PATCH 07/31] fix sp bug in hybrid sp&dp mode --- Llama3-70B.sh | 4 ++-- src/llamafactory/train/sft/trainer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Llama3-70B.sh b/Llama3-70B.sh index b2a54ad7fb..4cb5bb407c 100644 --- a/Llama3-70B.sh +++ b/Llama3-70B.sh @@ -39,7 +39,7 @@ src/train.py \ --overwrite_cache \ --preprocessing_num_workers 16 \ --output_dir ./output/70B_32K_bs_1M_rope_1M_step_1000_lr_2e-5 \ ---logging_steps 10 \ +--logging_steps 1 \ --save_steps 500 \ --plot_loss \ --overwrite_output_dir \ @@ -53,7 +53,7 @@ src/train.py \ --ddp_timeout 180000000 \ --val_size 0.1 \ --eval_strategy steps \ ---eval_steps 2 +--eval_steps 1000 # In the saved files, there are model-00001-of-00003.safetensors to model-00001-of-00003.safetensors. Somehow model.safetensors is unnecessary and should be removed. 
# rm output/7B_32K_bs_1M_rope_1M_step_1000_lr_2e-5/model.safetensors diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index e1349eacdd..342bc3bc05 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -278,7 +278,7 @@ def get_eval_dataloader(self, eval_dataset) -> DataLoader: self._eval_dataloader = eval_dataloader if hasattr(data_collator, "seq_algo") and data_collator.seq_algo != "data_parallel": - seq_parallel_size = self.args.seq_parallel_size + seq_parallel_size = self.finetuning_args.seq_parallel_size if seq_parallel_size != -1: world_size = int(os.environ['WORLD_SIZE']) assert seq_parallel_size != 0 and world_size % seq_parallel_size == 0, f"world_size: {world_size} should be devide by seq_parallel_size: {seq_parallel_size}" From c039a82c94acbcd48902d4f930b95c31f7ccfe56 Mon Sep 17 00:00:00 2001 From: qianhao0713 <475483052@qq.com> Date: Thu, 11 Jul 2024 21:23:19 +0800 Subject: [PATCH 08/31] fix bug when sequence_parallel_size=-1 --- .../easy_context/dist_flash_attn/async_communication.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llamafactory/easy_context/dist_flash_attn/async_communication.py b/src/llamafactory/easy_context/dist_flash_attn/async_communication.py index a70f5f0fa7..e67e88a4f7 100644 --- a/src/llamafactory/easy_context/dist_flash_attn/async_communication.py +++ b/src/llamafactory/easy_context/dist_flash_attn/async_communication.py @@ -65,7 +65,7 @@ def _initialize_sequence_parallel(sequence_parallel_size=None): world_size: int = torch.distributed.get_world_size() print(f"sequence_parallel_size is {sequence_parallel_size}, world_size is {world_size}") - if sequence_parallel_size is None: + if sequence_parallel_size is None or sequence_parallel_size == -1: sequence_parallel_size = world_size else: assert world_size % sequence_parallel_size == 0 From b7007ece37c2f9426c8b200bc48d8b56315b9899 Mon Sep 17 00:00:00 2001 From: qianhao0713 <475483052@qq.com> Date: Fri, 12 Jul 2024 10:59:53 +0800 Subject: [PATCH 09/31] fix loss normalizer for sp&dp hybrid --- src/llamafactory/train/sft/trainer.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index 342bc3bc05..a4e65ef94d 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -136,6 +136,7 @@ def save_predictions(self, predict_results: "PredictionOutput") -> None: writer.write("\n".join(res)) class CustomSeqParallelTrainer(CustomSeq2SeqTrainer): + from transformers.trainer import _is_peft_model, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES def compute_loss(self, model, inputs, return_outputs=False): """ How the loss is computed by Trainer. By default, all models return the loss in the first element. 
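The hunks above all make the same change: the peer handed to P2POp is a global rank, so once the sequence-parallel group is only a subgroup of the world (the hybrid SP&dp case) the group-local rank has to be shifted by the first global rank of that group, which is what seq_offset = get_process_group_ranks(seq_group)[0] supplies, and PATCH 08 additionally lets sp_size=-1 fall back to the full world size. A minimal sketch of that mapping, assuming SP groups are built from contiguous blocks of global ranks as _initialize_sequence_parallel creates them (to_global_peer is an illustrative helper, not a function in the patch):

import torch.distributed as dist

def to_global_peer(local_peer: int, seq_group) -> int:
    # P2POp peers are global ranks; an SP subgroup covering the contiguous
    # block [offset, offset + sp_size) therefore shifts its group-local
    # peer index by the group's first global rank.
    seq_offset = dist.get_process_group_ranks(seq_group)[0]
    return local_peer + seq_offset

# e.g. world_size=8, sp_size=4: ranks 4..7 form the second SP group, and
# group-local peer 1 inside that group is global rank 5.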
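The compute_loss hunk that follows implements the matching loss normalizer for the hybrid setup: each rank computes a sum-reduced cross entropy over its shard of every sequence, and the normalizer is the number of valid labels of that sequence across this rank's SP group rather than across all GPUs. A rough sketch of the intent, assuming accelerator.gather stacks the per-rank [1, bs] counts into [world_size, bs] and that the ranks of one SP group are adjacent (sp_loss is an illustrative helper, not the trainer code):

import torch
from torch.nn import CrossEntropyLoss

def sp_loss(logits, labels, accelerator, sp_size):
    valid = (labels != -100).sum(1)[None, :]          # [1, bs] valid tokens on this shard
    gathered = accelerator.gather(valid)              # [world_size, bs]
    n_gpus = gathered.shape[0]
    sp_size = n_gpus if sp_size == -1 else sp_size
    dp_rank = accelerator.process_index // sp_size
    # valid tokens of each sample summed over this rank's SP group only
    denom = gathered[dp_rank * sp_size:(dp_rank + 1) * sp_size].sum(0)
    loss_fn = CrossEntropyLoss(reduction="sum")
    per_sample = torch.stack(
        [loss_fn(logits[b], labels[b]) / denom[b] for b in range(len(labels))]
    )
    # scaling by sp_size keeps the later mean over all ranks equivalent to a
    # mean over DP replicas, since every rank in an SP group holds a shard
    # of the same sequences.
    return per_sample.mean() * sp_size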
@@ -172,13 +173,18 @@ def compute_loss(self, model, inputs, return_outputs=False): if self.finetuning_args.parallel_mode== "data_parallel": loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] else: + sp_size = self.finetuning_args.seq_parallel_size loss_fn = CrossEntropyLoss(reduction='sum') labels = inputs.pop("labels") logits = outputs["logits"] if isinstance(outputs, dict) else outputs[1] valid_label_cnt = (labels!=-100).sum(1)[None, :] valid_label_cnt_gather = self.accelerator.gather(valid_label_cnt) n_gpus = valid_label_cnt_gather.shape[0] - valid_label_cnt_all =valid_label_cnt_gather.sum(0) + if sp_size == -1: + sp_size = n_gpus + dp_size = n_gpus // sp_size + dp_rank = self.accelerator.process_index // dp_size + valid_label_cnt_all =valid_label_cnt_gather[dp_rank * sp_size : (dp_rank+1) * sp_size].sum(0) shift_logits = logits.contiguous() shift_labels = labels.contiguous() bs = len(shift_labels) From 5544a757e63aba69053181d0598c22525e494eb0 Mon Sep 17 00:00:00 2001 From: qianhao0713 <475483052@qq.com> Date: Fri, 12 Jul 2024 14:34:44 +0800 Subject: [PATCH 10/31] fix bug --- src/llamafactory/data/collator.py | 16 ++++++++++++---- src/llamafactory/train/sft/trainer.py | 2 +- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index db7c7f5c11..8a0ae59c23 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -99,8 +99,12 @@ def __call__(self, features: Sequence[Dict[str, Any]], return_tensors=None) -> D input_ids = batch["input_ids"] attention_mask = batch["attention_mask"] labels = batch["labels"] + seq_rank = self.rank + seq_worlds_size = self.world_size if self.seq_parallel_size != -1: dp_rank = self.rank // self.seq_parallel_size + seq_rank = self.rank % self.seq_parallel_size + seq_worlds_size = self.seq_parallel_size bs = len(input_ids) data_group_size = self.world_size // self.seq_parallel_size group_bs = bs // data_group_size @@ -112,8 +116,8 @@ def __call__(self, features: Sequence[Dict[str, Any]], return_tensors=None) -> D attention_mask=attention_mask, position_ids=None, labels=labels, - rank=self.rank, - world_size=self.world_size, + rank=seq_rank, + world_size=seq_worlds_size, device=self.device) return batch @@ -134,8 +138,12 @@ def __call__(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dic batch = super().__call__(examples) if self.seq_algo == "data_parallel": return batch + seq_rank = self.rank + seq_worlds_size = self.world_size if self.seq_parallel_size != -1: dp_rank = self.rank // self.seq_parallel_size + seq_rank = self.rank % self.seq_parallel_size + seq_worlds_size = self.seq_parallel_size bs = len(input_ids) data_group_size = self.world_size // self.seq_parallel_size group_bs = bs // data_group_size @@ -150,7 +158,7 @@ def __call__(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dic attention_mask=attention_mask, position_ids=None, labels=labels, - rank=self.rank, - world_size=self.world_size, + rank=seq_rank, + world_size=seq_worlds_size, device=self.device) return batch diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index a4e65ef94d..7ae2f8f857 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -192,7 +192,7 @@ def compute_loss(self, model, inputs, return_outputs=False): for b in range(bs): normalizer=valid_label_cnt_all[b].item() loss[b]=loss_fn(shift_logits[b], shift_labels[b])/normalizer - loss = 
loss.mean()*n_gpus + loss = loss.mean()*sp_size return (loss, outputs) if return_outputs else loss From 37d097a29707b2b87f33fe4a986d9c6de11cfcfa Mon Sep 17 00:00:00 2001 From: qianhao0713 <475483052@qq.com> Date: Fri, 12 Jul 2024 16:29:59 +0800 Subject: [PATCH 11/31] fix bug --- src/llamafactory/train/sft/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index 7ae2f8f857..2abc5e8e40 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -183,8 +183,8 @@ def compute_loss(self, model, inputs, return_outputs=False): if sp_size == -1: sp_size = n_gpus dp_size = n_gpus // sp_size - dp_rank = self.accelerator.process_index // dp_size - valid_label_cnt_all =valid_label_cnt_gather[dp_rank * sp_size : (dp_rank+1) * sp_size].sum(0) + dp_rank = self.accelerator.process_index // sp_size + valid_label_cnt_all =valid_label_cnt_gather[dp_rank * sp_size : (dp_rank+1) * sp_size].sum(0).detach() shift_logits = logits.contiguous() shift_labels = labels.contiguous() bs = len(shift_labels) From d5e513528ecee20a5eaf74bafa8dce451f401f80 Mon Sep 17 00:00:00 2001 From: qianhao0713 <475483052@qq.com> Date: Mon, 15 Jul 2024 15:39:24 +0800 Subject: [PATCH 12/31] rename variables --- Llama3-70B.sh | 4 +- Llama3-8B.sh | 4 +- src/llamafactory/data/collator.py | 48 +++++++++---------- src/llamafactory/easy_context/__init__.py | 4 +- .../dist_flash_attn/async_communication.py | 4 +- .../dist_flash_attn/monkey_patch.py | 4 +- src/llamafactory/hparams/finetuning_args.py | 4 +- src/llamafactory/train/sft/trainer.py | 23 ++++----- src/llamafactory/train/sft/workflow.py | 6 +-- 9 files changed, 51 insertions(+), 50 deletions(-) diff --git a/Llama3-70B.sh b/Llama3-70B.sh index 4cb5bb407c..72fe710ea2 100644 --- a/Llama3-70B.sh +++ b/Llama3-70B.sh @@ -5,7 +5,7 @@ NGPUS=${NGPUS:-8} WORLD_SIZE=${WORLD_SIZE:-1} NUM_PROCESSES=$[${NGPUS}*$[WORLD_SIZE]] SEQ_LEN=${SEQ_LEN:-32768} -SP_GROUP_SIZE=${SP_GROUP_SIZE:-1} +SP_SIZE=${SP_SIZE:-1} BATCH_SIZE=${BATCH_SIZE:-1} export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' export WANDB_DISABLED=true @@ -30,7 +30,7 @@ src/train.py \ --do_train \ --finetuning_type full \ --parallel_mode dist_flash_attn \ ---seq_parallel_size ${SP_GROUP_SIZE} \ +--sp_size ${SP_SIZE} \ --deepspeed examples/deepspeed/ds_z3_offload_config.json \ --dataset long_sft_128k \ --template llama3 \ diff --git a/Llama3-8B.sh b/Llama3-8B.sh index 46c081b412..0e67d291e9 100644 --- a/Llama3-8B.sh +++ b/Llama3-8B.sh @@ -5,7 +5,7 @@ NGPUS=${NGPUS:-8} WORLD_SIZE=${WORLD_SIZE:-1} NUM_PROCESSES=$[${NGPUS}*$[WORLD_SIZE]] SEQ_LEN=${SEQ_LEN:-1024} -SP_GROUP_SIZE=${SP_GROUP_SIZE:-1} +SP_SIZE=${SP_SIZE:-1} BATCH_SIZE=${BATCH_SIZE:-1} export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' export WANDB_DISABLED=true @@ -31,7 +31,7 @@ src/train.py \ --finetuning_type full \ --lora_target all \ --parallel_mode dist_flash_attn \ ---seq_parallel_size ${SP_GROUP_SIZE} \ +--sp_size ${SP_SIZE} \ --deepspeed examples/deepspeed/ds_z3_offload_config.json \ --dataset alpaca_en \ --template llama3 \ diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 8a0ae59c23..c38dfee1c9 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -87,7 +87,7 @@ class SeqParallelDataCollator(DataCollatorForSeq2Seq): Data collator for sequence parallel in supervised finetune(sft) stage. 
""" seq_algo: str = "data_parallel", - seq_parallel_size: int = -1 + sp_size: int = -1 rank: int = 0 world_size: int = 8 device: Optional[Any] = None @@ -99,15 +99,15 @@ def __call__(self, features: Sequence[Dict[str, Any]], return_tensors=None) -> D input_ids = batch["input_ids"] attention_mask = batch["attention_mask"] labels = batch["labels"] - seq_rank = self.rank - seq_worlds_size = self.world_size - if self.seq_parallel_size != -1: - dp_rank = self.rank // self.seq_parallel_size - seq_rank = self.rank % self.seq_parallel_size - seq_worlds_size = self.seq_parallel_size + world_size = self.world_size + sp_rank = self.rank + if self.sp_size != -1: + dp_rank = self.rank // self.sp_size + sp_rank = self.rank % self.sp_size + world_size = self.sp_size bs = len(input_ids) - data_group_size = self.world_size // self.seq_parallel_size - group_bs = bs // data_group_size + dp_size = self.world_size // self.sp_size + group_bs = bs // dp_size input_ids = input_ids[dp_rank * group_bs: (dp_rank + 1) * group_bs] attention_mask = attention_mask[dp_rank * group_bs: (dp_rank + 1) * group_bs] labels = labels[dp_rank * group_bs: (dp_rank + 1) * group_bs] @@ -116,8 +116,8 @@ def __call__(self, features: Sequence[Dict[str, Any]], return_tensors=None) -> D attention_mask=attention_mask, position_ids=None, labels=labels, - rank=seq_rank, - world_size=seq_worlds_size, + rank=sp_rank, + world_size=world_size, device=self.device) return batch @@ -138,27 +138,27 @@ def __call__(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dic batch = super().__call__(examples) if self.seq_algo == "data_parallel": return batch - seq_rank = self.rank - seq_worlds_size = self.world_size - if self.seq_parallel_size != -1: - dp_rank = self.rank // self.seq_parallel_size - seq_rank = self.rank % self.seq_parallel_size - seq_worlds_size = self.seq_parallel_size + input_ids = batch["input_ids"] + attention_mask = batch["attention_mask"] + labels = batch["labels"] + world_size = self.world_size + sp_rank = self.rank + if self.sp_size != -1: + dp_rank = self.rank // self.sp_size + sp_rank = self.rank % self.sp_size + world_size = self.sp_size bs = len(input_ids) - data_group_size = self.world_size // self.seq_parallel_size - group_bs = bs // data_group_size + dp_size = self.world_size // self.sp_size + group_bs = bs // dp_size input_ids = input_ids[dp_rank * group_bs: (dp_rank + 1) * group_bs] attention_mask = attention_mask[dp_rank * group_bs: (dp_rank + 1) * group_bs] labels = labels[dp_rank * group_bs: (dp_rank + 1) * group_bs] - input_ids = batch["input_ids"] - attention_mask = batch["attention_mask"] - labels = batch["labels"] batch = prepare_seq_parallel_sft_inputs(self.seq_algo, input_ids=input_ids, attention_mask=attention_mask, position_ids=None, labels=labels, - rank=seq_rank, - world_size=seq_worlds_size, + rank=sp_rank, + world_size=world_size, device=self.device) return batch diff --git a/src/llamafactory/easy_context/__init__.py b/src/llamafactory/easy_context/__init__.py index bef8ecab74..8f72c15786 100644 --- a/src/llamafactory/easy_context/__init__.py +++ b/src/llamafactory/easy_context/__init__.py @@ -64,7 +64,7 @@ def prepare_seq_parallel_sft_inputs( raise ValueError(f"Invalid seq_algo: {seq_algo}") def apply_seq_parallel_monkey_patch( - seq_algo, model,seq_parallel_size=None + seq_algo, model, sp_size=None ): assert seq_algo in ["zigzag_ring_attn", "dist_flash_attn", "ulysses_attn", "data_parallel"], f"Invalid seq_algo: {seq_algo}" assert model in ["llama", "mistral"], f"Invalid model: {model}" @@ 
-75,7 +75,7 @@ def apply_seq_parallel_monkey_patch( elif seq_algo == "zigzag_ring_attn" and model == "mistral": apply_zigzag_ring_attn_monkey_patch_mistral() elif seq_algo == "dist_flash_attn" and model == "llama": - apply_dist_flash_attn_monkey_patch_llama(seq_parallel_size=seq_parallel_size) + apply_dist_flash_attn_monkey_patch_llama(sp_size=sp_size) elif seq_algo == "ulysses_attn" and model == "llama": apply_ulysses_attn_monkey_patch_llama() else: diff --git a/src/llamafactory/easy_context/dist_flash_attn/async_communication.py b/src/llamafactory/easy_context/dist_flash_attn/async_communication.py index e67e88a4f7..68b35b5ae6 100644 --- a/src/llamafactory/easy_context/dist_flash_attn/async_communication.py +++ b/src/llamafactory/easy_context/dist_flash_attn/async_communication.py @@ -39,7 +39,7 @@ _bwd_send_volume = 0 _bwd_recv_volume = 0 -def initialize_distributed(sequence_parallel_size=None): +def initialize_distributed(sp_size=None): if dist.is_initialized(): if dist.get_rank() == 0: print( @@ -55,7 +55,7 @@ def initialize_distributed(sequence_parallel_size=None): global_world_size = dist.get_world_size() torch.cuda.set_device(dist.get_rank() % local_world_size) - _initialize_sequence_parallel(sequence_parallel_size=sequence_parallel_size) + _initialize_sequence_parallel(sp_size) # create_nccl_communicators() def _initialize_sequence_parallel(sequence_parallel_size=None): diff --git a/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py b/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py index 63ba6f4973..317b8b2748 100644 --- a/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py +++ b/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py @@ -602,7 +602,7 @@ def custom_forward(*inputs): ) -def apply_dist_flash_attn_monkey_patch_llama(seq_parallel_size=None): - initialize_distributed(sequence_parallel_size=seq_parallel_size) +def apply_dist_flash_attn_monkey_patch_llama(sp_size=None): + initialize_distributed(sp_size=sp_size) transformers.models.llama.modeling_llama.LlamaModel.forward = forward transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = llama_layer_forward diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index b636afdf37..eddbf3217f 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -316,10 +316,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA default="data_parallel", metadata={"help": "which sequence parallel mode to use."}, ) - seq_parallel_size: int = field( + sp_size: int = field( default=-1, metadata={ - "help": "used for use seq_parallel and data_parallel simultaneously, -1 for all gpus parallels in sequence_length axis, n for n_gpus makes a sequence_parallel group" + "help": "allow using seq_parallel and data_parallel simultaneously, -1 for all gpus parallels in sequence_length axis, n for n_gpus makes a sequence_parallel group" } ) diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index 2abc5e8e40..6b1627f93c 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -136,13 +136,14 @@ def save_predictions(self, predict_results: "PredictionOutput") -> None: writer.write("\n".join(res)) class CustomSeqParallelTrainer(CustomSeq2SeqTrainer): - from transformers.trainer import _is_peft_model, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES + def compute_loss(self, model, inputs, return_outputs=False): 
""" How the loss is computed by Trainer. By default, all models return the loss in the first element. Subclass and override for custom behavior. """ + from transformers.trainer import _is_peft_model, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES if self.label_smoother is not None and "labels" in inputs: labels = inputs.pop("labels") else: @@ -229,12 +230,12 @@ def get_train_dataloader(self) -> DataLoader: dataloader_params["worker_init_fn"] = seed_worker dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor if hasattr(data_collator, "seq_algo") and data_collator.seq_algo != "data_parallel": - seq_parallel_size = self.finetuning_args.seq_parallel_size - if seq_parallel_size != -1: + sp_size = self.finetuning_args.sp_size + if sp_size != -1: world_size = int(os.environ['WORLD_SIZE']) - assert seq_parallel_size != 0 and world_size % seq_parallel_size == 0, f"world_size: {world_size} should be devide by seq_parallel_size: {seq_parallel_size}" - data_parallel_size = world_size // seq_parallel_size - dataloader_params["batch_size"] = dataloader_params["batch_size"] * data_parallel_size + assert sp_size != 0 and world_size % sp_size == 0, f"world_size: {world_size} should be devide by seq_parallel_size: {sp_size}" + dp_size = world_size // sp_size + dataloader_params["batch_size"] = dataloader_params["batch_size"] * dp_size return DataLoader(train_dataset, **dataloader_params) return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) @@ -284,11 +285,11 @@ def get_eval_dataloader(self, eval_dataset) -> DataLoader: self._eval_dataloader = eval_dataloader if hasattr(data_collator, "seq_algo") and data_collator.seq_algo != "data_parallel": - seq_parallel_size = self.finetuning_args.seq_parallel_size - if seq_parallel_size != -1: + sp_size = self.finetuning_args.sp_size + if sp_size != -1: world_size = int(os.environ['WORLD_SIZE']) - assert seq_parallel_size != 0 and world_size % seq_parallel_size == 0, f"world_size: {world_size} should be devide by seq_parallel_size: {seq_parallel_size}" - data_parallel_size = world_size // seq_parallel_size - dataloader_params["batch_size"] = dataloader_params["batch_size"] * data_parallel_size + assert sp_size != 0 and world_size % sp_size == 0, f"world_size: {world_size} should be devide by seq_parallel_size: {sp_size}" + dp_size = world_size // sp_size + dataloader_params["batch_size"] = dataloader_params["batch_size"] * dp_size return eval_dataloader return self.accelerator.prepare(eval_dataloader) diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py index c3a082ad38..e570d3ef71 100644 --- a/src/llamafactory/train/sft/workflow.py +++ b/src/llamafactory/train/sft/workflow.py @@ -37,7 +37,7 @@ def run_sft( tokenizer = tokenizer_module["tokenizer"] dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) - apply_seq_parallel_monkey_patch(finetuning_args.parallel_mode, "llama", seq_parallel_size=finetuning_args.seq_parallel_size) + apply_seq_parallel_monkey_patch(finetuning_args.parallel_mode, "llama", sp_size=finetuning_args.sp_size) if training_args.predict_with_generate: tokenizer.padding_side = "left" # use left-padding in generation @@ -53,7 +53,7 @@ def run_sft( pad_to_multiple_of=data_args.cutoff_len if tokenizer.padding_side == "right" else None, label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id, 
seq_algo=finetuning_args.parallel_mode, - seq_parallel_size=finetuning_args.seq_parallel_size, + sp_size=finetuning_args.sp_size, rank=torch.distributed.get_rank(), world_size=world_size, device=torch.device("cuda", local_rank) @@ -87,7 +87,7 @@ def run_sft( trainer.save_model() trainer.log_metrics("train", train_result.metrics) trainer.save_metrics("train", train_result.metrics) - # trainer.save_state() + trainer.save_state() if trainer.is_world_process_zero() and finetuning_args.plot_loss: plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) From d31c8f764c7ac1f4e2c26f371b0ea2d4adfd872c Mon Sep 17 00:00:00 2001 From: qianhao0713 <475483052@qq.com> Date: Mon, 15 Jul 2024 15:49:58 +0800 Subject: [PATCH 13/31] fix 70b launch shell --- Llama3-70B.sh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/Llama3-70B.sh b/Llama3-70B.sh index 72fe710ea2..b139933734 100644 --- a/Llama3-70B.sh +++ b/Llama3-70B.sh @@ -9,6 +9,7 @@ SP_SIZE=${SP_SIZE:-1} BATCH_SIZE=${BATCH_SIZE:-1} export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' export WANDB_DISABLED=true +export NCCL_DEBUG=WARN echo ${RANK}/$[WORLD_SIZE] if [ ${MASTER_ADDR} == 'localhost' ]; then export MASTER_ADDR=`hostname -i` @@ -35,7 +36,7 @@ src/train.py \ --dataset long_sft_128k \ --template llama3 \ --cutoff_len ${SEQ_LEN} \ ---max_samples 1000 \ +--max_steps 10 \ --overwrite_cache \ --preprocessing_num_workers 16 \ --output_dir ./output/70B_32K_bs_1M_rope_1M_step_1000_lr_2e-5 \ @@ -46,7 +47,7 @@ src/train.py \ --per_device_train_batch_size ${BATCH_SIZE} \ --gradient_accumulation_steps 4 \ --learning_rate 2e-5 \ ---num_train_epochs 3.0 \ +--num_train_epochs 1.0 \ --lr_scheduler_type cosine \ --warmup_ratio 0.1 \ --bf16 \ From 4a4ea30960a10865c1afe0849b0d3a18a3aec7ff Mon Sep 17 00:00:00 2001 From: qianhao0713 <475483052@qq.com> Date: Mon, 15 Jul 2024 16:08:37 +0800 Subject: [PATCH 14/31] fix bug --- src/llamafactory/train/sft/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index 6b1627f93c..37de8b8ea2 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -174,7 +174,7 @@ def compute_loss(self, model, inputs, return_outputs=False): if self.finetuning_args.parallel_mode== "data_parallel": loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] else: - sp_size = self.finetuning_args.seq_parallel_size + sp_size = self.finetuning_args.sp_size loss_fn = CrossEntropyLoss(reduction='sum') labels = inputs.pop("labels") logits = outputs["logits"] if isinstance(outputs, dict) else outputs[1] From 34f70cec659d6a508d2d9172442247116be860bd Mon Sep 17 00:00:00 2001 From: qianhao0713 <475483052@qq.com> Date: Mon, 15 Jul 2024 19:08:29 +0800 Subject: [PATCH 15/31] add dp&sp hybrid for cpt --- src/llamafactory/data/collator.py | 2 +- src/llamafactory/train/pt/trainer.py | 14 ++++++++++++++ src/llamafactory/train/pt/workflow.py | 3 ++- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index c38dfee1c9..29023f25f9 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -129,7 +129,7 @@ class SeqParallelDataCollatorForLanguageModeling(DataCollatorForLanguageModeling Reuse the sequence parallel distributing function for sft stage. 
""" seq_algo: str = "data_parallel" - seq_parallel_size: int = -1 + sp_size: int = -1 rank: int = 0 world_size: int = 8 device: Optional[Any] = None diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py index 18f73d8512..6a34b27238 100644 --- a/src/llamafactory/train/pt/trainer.py +++ b/src/llamafactory/train/pt/trainer.py @@ -11,6 +11,7 @@ from transformers.trainer_utils import seed_worker import datasets from torch.nn import CrossEntropyLoss +import os if TYPE_CHECKING: import torch @@ -62,6 +63,7 @@ def compute_loss(self, model, inputs, return_outputs=False): Subclass and override for custom behavior. """ + from transformers.trainer import _is_peft_model, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES if self.label_smoother is not None and "labels" in inputs: labels = inputs.pop("labels") else: @@ -148,6 +150,12 @@ def get_train_dataloader(self) -> DataLoader: dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor if hasattr(data_collator, "seq_algo") and data_collator.seq_algo != "data_parallel": + sp_size = self.finetuning_args.sp_size + if sp_size != -1: + world_size = int(os.environ['WORLD_SIZE']) + assert sp_size != 0 and world_size % sp_size == 0, f"world_size: {world_size} should be devide by seq_parallel_size: {sp_size}" + dp_size = world_size // sp_size + dataloader_params["batch_size"] = dataloader_params["batch_size"] * dp_size return DataLoader(train_dataset, **dataloader_params) return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) @@ -197,5 +205,11 @@ def get_eval_dataloader(self, eval_dataset) -> DataLoader: self._eval_dataloader = eval_dataloader if hasattr(data_collator, "seq_algo") and data_collator.seq_algo != "data_parallel": + sp_size = self.finetuning_args.sp_size + if sp_size != -1: + world_size = int(os.environ['WORLD_SIZE']) + assert sp_size != 0 and world_size % sp_size == 0, f"world_size: {world_size} should be devide by seq_parallel_size: {sp_size}" + dp_size = world_size // sp_size + dataloader_params["batch_size"] = dataloader_params["batch_size"] * dp_size return eval_dataloader return self.accelerator.prepare(eval_dataloader) \ No newline at end of file diff --git a/src/llamafactory/train/pt/workflow.py b/src/llamafactory/train/pt/workflow.py index 1d9a36f673..fc33acffae 100644 --- a/src/llamafactory/train/pt/workflow.py +++ b/src/llamafactory/train/pt/workflow.py @@ -32,7 +32,7 @@ def run_pt( tokenizer = tokenizer_module["tokenizer"] dataset = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) - apply_seq_parallel_monkey_patch(finetuning_args.parallel_mode, "llama") + apply_seq_parallel_monkey_patch(finetuning_args.parallel_mode, "llama", sp_size=finetuning_args.sp_size) # data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) local_rank = int(os.getenv("LOCAL_RANK")) @@ -42,6 +42,7 @@ def run_pt( tokenizer=tokenizer, mlm=False, seq_algo=finetuning_args.parallel_mode, + sp_size=finetuning_args.sp_size, rank=torch.distributed.get_rank(), world_size=torch.distributed.get_world_size(), device=torch.device("cuda", local_rank) From 70bd600d8c9a10bd4e23fc06ed0190dfba21328e Mon Sep 17 00:00:00 2001 From: qianhao0713 <475483052@qq.com> Date: Mon, 15 Jul 2024 19:39:25 +0800 Subject: [PATCH 16/31] add cpt test launch shell --- Llama3-70B-pt-sp-test.sh | 180 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 180 insertions(+) create mode 100644 
Llama3-70B-pt-sp-test.sh diff --git a/Llama3-70B-pt-sp-test.sh b/Llama3-70B-pt-sp-test.sh new file mode 100644 index 0000000000..f9a2b836b6 --- /dev/null +++ b/Llama3-70B-pt-sp-test.sh @@ -0,0 +1,180 @@ +# You can observe that the number of steps for different stage is quite different. They are not magic number. They are set to those numbers simply because I esitimate the time it takes to finish the training, and +# choose the number such that it fits my daily schedule>_<. This is for you to exactly reproduce my results. You many change the steps to other numbers if you want to. +MODEL_DIR=${MODEL_DIR:-"/mnt/zj-gpfs/home/lsy/models/Meta-Llama-3-8B"} +NGPUS=${NGPUS:-8} +WORLD_SIZE=${WORLD_SIZE:-1} +NUM_PROCESSES=$((${NGPUS}*${WORLD_SIZE})) +SEQ_LEN=${SEQ_LEN:-32768} +SP_SIZE=${SP_SIZE:-8} +BATCH_SIZE=${BATCH_SIZE:-1} +ACC=${ACC:-4} +SAVE_PATH=${SAVE_PATH:-"./output/70B_32K_bs_1M_rope_1M_step_1000_lr_2e-5"} +SAVE_STEPS=${SAVE_STEPS:-500} +export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' +# export NCCL_P2P_DISABLE=1 +# export NCCL_IB_GID_INDEX=3 +export WANDB_DISABLED=true +export NCCL_DEBUG=WARN +# export CUDA_LAUNCH_BLOCKING=1 +echo ${RANK}/$[WORLD_SIZE] +if [ ${MASTER_ADDR} == 'localhost' ]; then + export MASTER_ADDR=`hostname -i` +fi +echo ${MASTER_ADDR}:${MASTER_PORT} +export TRANSFORMERS_VERBOSITY='debug' + +accelerate launch \ +--config_file examples/accelerate/ds_multi_nodes.yaml \ +--use_deepspeed \ +--num_machines ${WORLD_SIZE} \ +--num_processes ${NUM_PROCESSES} \ +--main_process_ip ${MASTER_ADDR} \ +--main_process_port ${MASTER_PORT} \ +--machine_rank ${RANK} \ +--rdzv_conf "rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT,rdzv_backend=c10d" \ +src/train.py \ +--model_name_or_path ${MODEL_DIR} \ +--stage pt \ +--do_train \ +--finetuning_type full \ +--parallel_mode dist_flash_attn \ +--deepspeed examples/deepspeed/ds_z3_offload_config.json \ +--tokenized_path /mnt/zj-gpfs/home/lsy/data/tokenized_c4_demo \ +--template llama3 \ +--cutoff_len ${SEQ_LEN} \ +--sp_size ${SP_SIZE} \ +--max_steps 10 \ +--overwrite_cache \ +--preprocessing_num_workers 16 \ +--output_dir ${SAVE_PATH} \ +--logging_steps 1 \ +--save_steps ${SAVE_STEPS} \ +--plot_loss \ +--overwrite_output_dir \ +--per_device_train_batch_size ${BATCH_SIZE} \ +--gradient_accumulation_steps ${ACC} \ +--learning_rate 2e-5 \ +--num_train_epochs 3.0 \ +--lr_scheduler_type cosine \ +--warmup_ratio 0.1 \ +--bf16 \ +--ddp_timeout 180000000 \ +--val_size 0.1 \ +--eval_strategy steps \ +--eval_steps 9999 \ +--dataloader_drop_last + +# In the saved files, there are model-00001-of-00003.safetensors to model-00001-of-00003.safetensors. Somehow model.safetensors is unnecessary and should be removed. 
+# rm output/7B_32K_bs_1M_rope_1M_step_1000_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 2 \ +# --output-dir ./output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5 \ +# --seed 2022 \ +# --wandb EasyContext \ +# --max-train-steps 1000 \ +# --learning-rate 2e-5 \ +# --dataset yaofu/slimpajama-per-source-length-upsample \ +# --model output/7B_32K_bs_1M_rope_1M_step_1000_lr_2e-5 \ +# --seq-length 65536 \ +# --rope-theta 5000000 \ +# --parallel_mode data_parallel + +# rm output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 4 \ +# --output-dir ./output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5 \ +# --seed 2023 \ +# --wandb EasyContext \ +# --max-train-steps 500 \ +# --learning-rate 2e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ +# --model output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5 \ +# --seq-length 256000 \ +# --rope-theta 10000000 \ +# --parallel_mode zigzag_ring_attn + +# rm output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 4 \ +# --output-dir ./output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5 \ +# --seed 2024 \ +# --wandb EasyContext \ +# --max-train-steps 500 \ +# --learning-rate 2e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ +# --model output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5 \ +# --seq-length 256000 \ +# --rope-theta 25000000 \ +# --parallel_mode zigzag_ring_attn + +# rm output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 4 \ +# --output-dir ./output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5 \ +# --seed 2025 \ +# --wandb EasyContext \ +# --max-train-steps 150 \ +# --learning-rate 2e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ +# --model output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5 \ +# --seq-length 256000 \ +# --rope-theta 50000000 \ +# --parallel_mode zigzag_ring_attn + +# rm output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 2 \ +# --output-dir ./output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5 \ +# --seed 2026 \ +# --wandb EasyContext \ +# --max-train-steps 300 \ +# --learning-rate 2e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_1M \ +# --model output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5 \ +# --seq-length 512000 \ +# --rope-theta 100000000 \ +# --parallel_mode zigzag_ring_attn + + +# rm output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5/model.safetensors + +# accelerate launch \ +# --config_file accelerate_configs/single_node.yaml \ +# train.py \ +# --batch-size 1 \ +# --gradient-accumulate-every 2 \ +# --output-dir ./output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5 \ +# --seed 2027 \ +# --wandb EasyContext \ +# --max-train-steps 90 \ +# --learning-rate 1e-5 \ +# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_1M \ +# --model output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5 \ +# --seq-length 
512000 \ +# --rope-theta 250000000 \ +# --parallel_mode zigzag_ring_attn + +# rm output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5/model.safetensors + + +### Finally we directly set the rope_theta in output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5/config.json to 1,000,000,000 From 89c28fb65f822dfaffa18c7aa3110123145c70e0 Mon Sep 17 00:00:00 2001 From: qianhao0713 <475483052@qq.com> Date: Mon, 15 Jul 2024 20:30:09 +0800 Subject: [PATCH 17/31] fix compute_loss for cpt --- src/llamafactory/train/pt/trainer.py | 23 ++++++++++++----------- src/llamafactory/train/sft/trainer.py | 1 - 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py index 6a34b27238..d357caf6cd 100644 --- a/src/llamafactory/train/pt/trainer.py +++ b/src/llamafactory/train/pt/trainer.py @@ -95,24 +95,25 @@ def compute_loss(self, model, inputs, return_outputs=False): loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] else: - loss_fn = CrossEntropyLoss() + sp_size = self.finetuning_args.sp_size + loss_fn = CrossEntropyLoss(reduction='sum') labels = inputs.pop("labels") logits = outputs["logits"] if isinstance(outputs, dict) else outputs[1] - - # valid_label_cnt = (labels!=-100).sum(1)[None, :] - # print(f"valid label cnt:{valid_label_cnt}") - # valid_label_cnt_gather = self.accelerator.gather(valid_label_cnt) - # # valid_label_cnt_gather:[ngpus, bs] - # n_gpus = valid_label_cnt_gather.shape[0] - # valid_label_cnt_all =valid_label_cnt_gather.sum(0) #[bs] + valid_label_cnt = (labels!=-100).sum(1)[None, :] + valid_label_cnt_gather = self.accelerator.gather(valid_label_cnt) + n_gpus = valid_label_cnt_gather.shape[0] + if sp_size == -1: + sp_size = n_gpus + dp_rank = self.accelerator.process_index // sp_size + valid_label_cnt_all =valid_label_cnt_gather[dp_rank * sp_size : (dp_rank+1) * sp_size].sum(0).detach() shift_logits = logits.contiguous() shift_labels = labels.contiguous() bs = len(shift_labels) loss = torch.zeros(bs, dtype=shift_logits.dtype, device=shift_labels.device) - for b in range(bs): - loss[b]=loss_fn(shift_logits[b], shift_labels[b]) - loss = loss.mean() + normalizer=valid_label_cnt_all[b].item() + loss[b]=loss_fn(shift_logits[b], shift_labels[b])/normalizer + loss = loss.mean()*sp_size return (loss, outputs) if return_outputs else loss diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index 37de8b8ea2..32301bf3f9 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -183,7 +183,6 @@ def compute_loss(self, model, inputs, return_outputs=False): n_gpus = valid_label_cnt_gather.shape[0] if sp_size == -1: sp_size = n_gpus - dp_size = n_gpus // sp_size dp_rank = self.accelerator.process_index // sp_size valid_label_cnt_all =valid_label_cnt_gather[dp_rank * sp_size : (dp_rank+1) * sp_size].sum(0).detach() shift_logits = logits.contiguous() From 8f43fc1749a7e1d03d31f0b7b6be959566eeda3d Mon Sep 17 00:00:00 2001 From: qianhao0713 <475483052@qq.com> Date: Tue, 30 Jul 2024 16:03:37 +0800 Subject: [PATCH 18/31] arange launch shell --- Llama3-70B-pt-dp.sh | 178 ----------------- Llama3-70B-pt-sp-lora.sh | 178 ----------------- Llama3-70B-pt-sp-test.sh | 180 ------------------ Llama3-70B-pt-sp.sh | 178 ----------------- Llama3-70B-sft.sh | 176 ----------------- Llama3-70B.sh | 172 ----------------- Llama3-8B.sh | 174 ----------------- tianqing_examples/Llama3-70B-pt-dp.sh | 59 ++++++ tianqing_examples/Llama3-70B-pt-sp-lora.sh | 59 ++++++ 
tianqing_examples/Llama3-70B-pt-sp.sh | 59 ++++++ tianqing_examples/Llama3-70B.sh | 57 ++++++ tianqing_examples/Llama3-8B.sh | 59 ++++++ .../llama3_full_sft_ds3.yaml | 0 13 files changed, 293 insertions(+), 1236 deletions(-) delete mode 100644 Llama3-70B-pt-dp.sh delete mode 100644 Llama3-70B-pt-sp-lora.sh delete mode 100644 Llama3-70B-pt-sp-test.sh delete mode 100644 Llama3-70B-pt-sp.sh delete mode 100644 Llama3-70B-sft.sh delete mode 100644 Llama3-70B.sh delete mode 100644 Llama3-8B.sh create mode 100644 tianqing_examples/Llama3-70B-pt-dp.sh create mode 100644 tianqing_examples/Llama3-70B-pt-sp-lora.sh create mode 100644 tianqing_examples/Llama3-70B-pt-sp.sh create mode 100644 tianqing_examples/Llama3-70B.sh create mode 100644 tianqing_examples/Llama3-8B.sh rename llama3_full_sft_ds3.yaml => tianqing_examples/llama3_full_sft_ds3.yaml (100%) diff --git a/Llama3-70B-pt-dp.sh b/Llama3-70B-pt-dp.sh deleted file mode 100644 index c9f7de7dd4..0000000000 --- a/Llama3-70B-pt-dp.sh +++ /dev/null @@ -1,178 +0,0 @@ -# You can observe that the number of steps for different stage is quite different. They are not magic number. They are set to those numbers simply because I esitimate the time it takes to finish the training, and -# choose the number such that it fits my daily schedule>_<. This is for you to exactly reproduce my results. You many change the steps to other numbers if you want to. -MODEL_DIR=${MODEL_DIR:-"/mnt/zj-gpfs/home/lsy/models/Meta-Llama-3-8B"} -NGPUS=${NGPUS:-8} -WORLD_SIZE=${WORLD_SIZE:-1} -NUM_PROCESSES=$((${NGPUS}*${WORLD_SIZE})) -SEQ_LEN=${SEQ_LEN:-32768} -BATCH_SIZE=${BATCH_SIZE:-1} -ACC=${ACC:-4} -SAVE_PATH=${SAVE_PATH:-"./output/70B_32K_bs_1M_rope_1M_step_1000_lr_2e-5"} -SAVE_STEPS=${SAVE_STEPS:-500} -export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' -export NCCL_P2P_DISABLE=1 -export NCCL_IB_GID_INDEX=3 -export WANDB_DISABLED=true -export NCCL_DEBUG=WARN -export CUDA_LAUNCH_BLOCKING=1 -echo ${RANK}/$[WORLD_SIZE] -if [ ${MASTER_ADDR} == 'localhost' ]; then - export MASTER_ADDR=`hostname -i` -fi -echo ${MASTER_ADDR}:${MASTER_PORT} -export TRANSFORMERS_VERBOSITY='debug' - -accelerate launch \ ---config_file examples/accelerate/ds_multi_nodes.yaml \ ---use_deepspeed \ ---num_machines ${WORLD_SIZE} \ ---num_processes ${NUM_PROCESSES} \ ---main_process_ip ${MASTER_ADDR} \ ---main_process_port ${MASTER_PORT} \ ---machine_rank ${RANK} \ ---rdzv_conf "rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT,rdzv_backend=c10d" \ -src/train.py \ ---model_name_or_path ${MODEL_DIR} \ ---stage pt \ ---do_train \ ---finetuning_type full \ ---parallel_mode data_parallel \ ---deepspeed examples/deepspeed/ds_z3_offload_config.json \ ---tokenized_path /mnt/zj-gpfs/home/lsy/data/tokenized_c4_demo \ ---template llama3 \ ---cutoff_len ${SEQ_LEN} \ ---max_samples 1000 \ ---overwrite_cache \ ---preprocessing_num_workers 16 \ ---output_dir ${SAVE_PATH} \ ---logging_steps 1 \ ---save_steps ${SAVE_STEPS} \ ---plot_loss \ ---overwrite_output_dir \ ---per_device_train_batch_size ${BATCH_SIZE} \ ---gradient_accumulation_steps ${ACC} \ ---learning_rate 2e-5 \ ---num_train_epochs 3.0 \ ---lr_scheduler_type cosine \ ---warmup_ratio 0.1 \ ---bf16 \ ---ddp_timeout 180000000 \ ---val_size 0.1 \ ---eval_strategy steps \ ---eval_steps 9999 \ ---dataloader_drop_last - -# In the saved files, there are model-00001-of-00003.safetensors to model-00001-of-00003.safetensors. Somehow model.safetensors is unnecessary and should be removed. 
-# rm output/7B_32K_bs_1M_rope_1M_step_1000_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 2 \ -# --output-dir ./output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5 \ -# --seed 2022 \ -# --wandb EasyContext \ -# --max-train-steps 1000 \ -# --learning-rate 2e-5 \ -# --dataset yaofu/slimpajama-per-source-length-upsample \ -# --model output/7B_32K_bs_1M_rope_1M_step_1000_lr_2e-5 \ -# --seq-length 65536 \ -# --rope-theta 5000000 \ -# --parallel_mode data_parallel - -# rm output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 4 \ -# --output-dir ./output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5 \ -# --seed 2023 \ -# --wandb EasyContext \ -# --max-train-steps 500 \ -# --learning-rate 2e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ -# --model output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5 \ -# --seq-length 256000 \ -# --rope-theta 10000000 \ -# --parallel_mode zigzag_ring_attn - -# rm output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 4 \ -# --output-dir ./output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5 \ -# --seed 2024 \ -# --wandb EasyContext \ -# --max-train-steps 500 \ -# --learning-rate 2e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ -# --model output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5 \ -# --seq-length 256000 \ -# --rope-theta 25000000 \ -# --parallel_mode zigzag_ring_attn - -# rm output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 4 \ -# --output-dir ./output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5 \ -# --seed 2025 \ -# --wandb EasyContext \ -# --max-train-steps 150 \ -# --learning-rate 2e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ -# --model output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5 \ -# --seq-length 256000 \ -# --rope-theta 50000000 \ -# --parallel_mode zigzag_ring_attn - -# rm output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 2 \ -# --output-dir ./output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5 \ -# --seed 2026 \ -# --wandb EasyContext \ -# --max-train-steps 300 \ -# --learning-rate 2e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_1M \ -# --model output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5 \ -# --seq-length 512000 \ -# --rope-theta 100000000 \ -# --parallel_mode zigzag_ring_attn - - -# rm output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 2 \ -# --output-dir ./output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5 \ -# --seed 2027 \ -# --wandb EasyContext \ -# --max-train-steps 90 \ -# --learning-rate 1e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_1M \ -# --model output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5 \ -# --seq-length 
512000 \ -# --rope-theta 250000000 \ -# --parallel_mode zigzag_ring_attn - -# rm output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5/model.safetensors - - -### Finally we directly set the rope_theta in output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5/config.json to 1,000,000,000 diff --git a/Llama3-70B-pt-sp-lora.sh b/Llama3-70B-pt-sp-lora.sh deleted file mode 100644 index f6367490ab..0000000000 --- a/Llama3-70B-pt-sp-lora.sh +++ /dev/null @@ -1,178 +0,0 @@ -# You can observe that the number of steps for different stage is quite different. They are not magic number. They are set to those numbers simply because I esitimate the time it takes to finish the training, and -# choose the number such that it fits my daily schedule>_<. This is for you to exactly reproduce my results. You many change the steps to other numbers if you want to. -MODEL_DIR=${MODEL_DIR:-"/mnt/zj-gpfs/home/lsy/models/Meta-Llama-3-8B"} -NGPUS=${NGPUS:-8} -WORLD_SIZE=${WORLD_SIZE:-1} -NUM_PROCESSES=$((${NGPUS}*${WORLD_SIZE})) -SEQ_LEN=${SEQ_LEN:-32768} -BATCH_SIZE=${BATCH_SIZE:-1} -ACC=${ACC:-4} -SAVE_PATH=${SAVE_PATH:-"./output/70B_32K_bs_1M_rope_1M_step_1000_lr_2e-5"} -SAVE_STEPS=${SAVE_STEPS:-500} -export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' -export NCCL_P2P_DISABLE=1 -export NCCL_IB_GID_INDEX=3 -export WANDB_DISABLED=true -export NCCL_DEBUG=WARN -export CUDA_LAUNCH_BLOCKING=1 -echo ${RANK}/$[WORLD_SIZE] -if [ ${MASTER_ADDR} == 'localhost' ]; then - export MASTER_ADDR=`hostname -i` -fi -echo ${MASTER_ADDR}:${MASTER_PORT} -export TRANSFORMERS_VERBOSITY='debug' - -accelerate launch \ ---config_file examples/accelerate/ds_multi_nodes.yaml \ ---use_deepspeed \ ---num_machines ${WORLD_SIZE} \ ---num_processes ${NUM_PROCESSES} \ ---main_process_ip ${MASTER_ADDR} \ ---main_process_port ${MASTER_PORT} \ ---machine_rank ${RANK} \ ---rdzv_conf "rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT,rdzv_backend=c10d" \ -src/train.py \ ---model_name_or_path ${MODEL_DIR} \ ---stage pt \ ---do_train \ ---finetuning_type lora \ ---parallel_mode dist_flash_attn \ ---deepspeed examples/deepspeed/ds_z3_offload_config.json \ ---tokenized_path ${DATA_PATH:-"/mnt/zj-gpfs/home/lsy/data/per_source_upsample_32769_common_5b"} \ ---template llama3 \ ---cutoff_len ${SEQ_LEN} \ ---max_samples 1000000 \ ---overwrite_cache \ ---preprocessing_num_workers 16 \ ---output_dir ${SAVE_PATH} \ ---logging_steps 1 \ ---save_steps ${SAVE_STEPS} \ ---plot_loss \ ---overwrite_output_dir \ ---per_device_train_batch_size ${BATCH_SIZE} \ ---gradient_accumulation_steps ${ACC} \ ---learning_rate 2e-5 \ ---num_train_epochs 3.0 \ ---lr_scheduler_type cosine \ ---warmup_ratio 0.1 \ ---bf16 \ ---ddp_timeout 180000000 \ ---val_size 0.1 \ ---eval_strategy steps \ ---eval_steps 99999999 \ ---dataloader_drop_last - -# In the saved files, there are model-00001-of-00003.safetensors to model-00001-of-00003.safetensors. Somehow model.safetensors is unnecessary and should be removed. 
-# rm output/7B_32K_bs_1M_rope_1M_step_1000_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 2 \ -# --output-dir ./output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5 \ -# --seed 2022 \ -# --wandb EasyContext \ -# --max-train-steps 1000 \ -# --learning-rate 2e-5 \ -# --dataset yaofu/slimpajama-per-source-length-upsample \ -# --model output/7B_32K_bs_1M_rope_1M_step_1000_lr_2e-5 \ -# --seq-length 65536 \ -# --rope-theta 5000000 \ -# --parallel_mode data_parallel - -# rm output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 4 \ -# --output-dir ./output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5 \ -# --seed 2023 \ -# --wandb EasyContext \ -# --max-train-steps 500 \ -# --learning-rate 2e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ -# --model output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5 \ -# --seq-length 256000 \ -# --rope-theta 10000000 \ -# --parallel_mode zigzag_ring_attn - -# rm output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 4 \ -# --output-dir ./output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5 \ -# --seed 2024 \ -# --wandb EasyContext \ -# --max-train-steps 500 \ -# --learning-rate 2e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ -# --model output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5 \ -# --seq-length 256000 \ -# --rope-theta 25000000 \ -# --parallel_mode zigzag_ring_attn - -# rm output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 4 \ -# --output-dir ./output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5 \ -# --seed 2025 \ -# --wandb EasyContext \ -# --max-train-steps 150 \ -# --learning-rate 2e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ -# --model output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5 \ -# --seq-length 256000 \ -# --rope-theta 50000000 \ -# --parallel_mode zigzag_ring_attn - -# rm output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 2 \ -# --output-dir ./output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5 \ -# --seed 2026 \ -# --wandb EasyContext \ -# --max-train-steps 300 \ -# --learning-rate 2e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_1M \ -# --model output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5 \ -# --seq-length 512000 \ -# --rope-theta 100000000 \ -# --parallel_mode zigzag_ring_attn - - -# rm output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 2 \ -# --output-dir ./output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5 \ -# --seed 2027 \ -# --wandb EasyContext \ -# --max-train-steps 90 \ -# --learning-rate 1e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_1M \ -# --model output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5 \ -# --seq-length 
512000 \ -# --rope-theta 250000000 \ -# --parallel_mode zigzag_ring_attn - -# rm output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5/model.safetensors - - -### Finally we directly set the rope_theta in output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5/config.json to 1,000,000,000 diff --git a/Llama3-70B-pt-sp-test.sh b/Llama3-70B-pt-sp-test.sh deleted file mode 100644 index f9a2b836b6..0000000000 --- a/Llama3-70B-pt-sp-test.sh +++ /dev/null @@ -1,180 +0,0 @@ -# You can observe that the number of steps for different stage is quite different. They are not magic number. They are set to those numbers simply because I esitimate the time it takes to finish the training, and -# choose the number such that it fits my daily schedule>_<. This is for you to exactly reproduce my results. You many change the steps to other numbers if you want to. -MODEL_DIR=${MODEL_DIR:-"/mnt/zj-gpfs/home/lsy/models/Meta-Llama-3-8B"} -NGPUS=${NGPUS:-8} -WORLD_SIZE=${WORLD_SIZE:-1} -NUM_PROCESSES=$((${NGPUS}*${WORLD_SIZE})) -SEQ_LEN=${SEQ_LEN:-32768} -SP_SIZE=${SP_SIZE:-8} -BATCH_SIZE=${BATCH_SIZE:-1} -ACC=${ACC:-4} -SAVE_PATH=${SAVE_PATH:-"./output/70B_32K_bs_1M_rope_1M_step_1000_lr_2e-5"} -SAVE_STEPS=${SAVE_STEPS:-500} -export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' -# export NCCL_P2P_DISABLE=1 -# export NCCL_IB_GID_INDEX=3 -export WANDB_DISABLED=true -export NCCL_DEBUG=WARN -# export CUDA_LAUNCH_BLOCKING=1 -echo ${RANK}/$[WORLD_SIZE] -if [ ${MASTER_ADDR} == 'localhost' ]; then - export MASTER_ADDR=`hostname -i` -fi -echo ${MASTER_ADDR}:${MASTER_PORT} -export TRANSFORMERS_VERBOSITY='debug' - -accelerate launch \ ---config_file examples/accelerate/ds_multi_nodes.yaml \ ---use_deepspeed \ ---num_machines ${WORLD_SIZE} \ ---num_processes ${NUM_PROCESSES} \ ---main_process_ip ${MASTER_ADDR} \ ---main_process_port ${MASTER_PORT} \ ---machine_rank ${RANK} \ ---rdzv_conf "rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT,rdzv_backend=c10d" \ -src/train.py \ ---model_name_or_path ${MODEL_DIR} \ ---stage pt \ ---do_train \ ---finetuning_type full \ ---parallel_mode dist_flash_attn \ ---deepspeed examples/deepspeed/ds_z3_offload_config.json \ ---tokenized_path /mnt/zj-gpfs/home/lsy/data/tokenized_c4_demo \ ---template llama3 \ ---cutoff_len ${SEQ_LEN} \ ---sp_size ${SP_SIZE} \ ---max_steps 10 \ ---overwrite_cache \ ---preprocessing_num_workers 16 \ ---output_dir ${SAVE_PATH} \ ---logging_steps 1 \ ---save_steps ${SAVE_STEPS} \ ---plot_loss \ ---overwrite_output_dir \ ---per_device_train_batch_size ${BATCH_SIZE} \ ---gradient_accumulation_steps ${ACC} \ ---learning_rate 2e-5 \ ---num_train_epochs 3.0 \ ---lr_scheduler_type cosine \ ---warmup_ratio 0.1 \ ---bf16 \ ---ddp_timeout 180000000 \ ---val_size 0.1 \ ---eval_strategy steps \ ---eval_steps 9999 \ ---dataloader_drop_last - -# In the saved files, there are model-00001-of-00003.safetensors to model-00001-of-00003.safetensors. Somehow model.safetensors is unnecessary and should be removed. 
-# rm output/7B_32K_bs_1M_rope_1M_step_1000_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 2 \ -# --output-dir ./output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5 \ -# --seed 2022 \ -# --wandb EasyContext \ -# --max-train-steps 1000 \ -# --learning-rate 2e-5 \ -# --dataset yaofu/slimpajama-per-source-length-upsample \ -# --model output/7B_32K_bs_1M_rope_1M_step_1000_lr_2e-5 \ -# --seq-length 65536 \ -# --rope-theta 5000000 \ -# --parallel_mode data_parallel - -# rm output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 4 \ -# --output-dir ./output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5 \ -# --seed 2023 \ -# --wandb EasyContext \ -# --max-train-steps 500 \ -# --learning-rate 2e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ -# --model output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5 \ -# --seq-length 256000 \ -# --rope-theta 10000000 \ -# --parallel_mode zigzag_ring_attn - -# rm output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 4 \ -# --output-dir ./output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5 \ -# --seed 2024 \ -# --wandb EasyContext \ -# --max-train-steps 500 \ -# --learning-rate 2e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ -# --model output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5 \ -# --seq-length 256000 \ -# --rope-theta 25000000 \ -# --parallel_mode zigzag_ring_attn - -# rm output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 4 \ -# --output-dir ./output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5 \ -# --seed 2025 \ -# --wandb EasyContext \ -# --max-train-steps 150 \ -# --learning-rate 2e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ -# --model output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5 \ -# --seq-length 256000 \ -# --rope-theta 50000000 \ -# --parallel_mode zigzag_ring_attn - -# rm output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 2 \ -# --output-dir ./output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5 \ -# --seed 2026 \ -# --wandb EasyContext \ -# --max-train-steps 300 \ -# --learning-rate 2e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_1M \ -# --model output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5 \ -# --seq-length 512000 \ -# --rope-theta 100000000 \ -# --parallel_mode zigzag_ring_attn - - -# rm output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 2 \ -# --output-dir ./output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5 \ -# --seed 2027 \ -# --wandb EasyContext \ -# --max-train-steps 90 \ -# --learning-rate 1e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_1M \ -# --model output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5 \ -# --seq-length 
512000 \ -# --rope-theta 250000000 \ -# --parallel_mode zigzag_ring_attn - -# rm output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5/model.safetensors - - -### Finally we directly set the rope_theta in output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5/config.json to 1,000,000,000 diff --git a/Llama3-70B-pt-sp.sh b/Llama3-70B-pt-sp.sh deleted file mode 100644 index 5d52895af0..0000000000 --- a/Llama3-70B-pt-sp.sh +++ /dev/null @@ -1,178 +0,0 @@ -# You can observe that the number of steps for different stage is quite different. They are not magic number. They are set to those numbers simply because I esitimate the time it takes to finish the training, and -# choose the number such that it fits my daily schedule>_<. This is for you to exactly reproduce my results. You many change the steps to other numbers if you want to. -MODEL_DIR=${MODEL_DIR:-"/mnt/zj-gpfs/home/lsy/models/Meta-Llama-3-8B"} -NGPUS=${NGPUS:-8} -WORLD_SIZE=${WORLD_SIZE:-1} -NUM_PROCESSES=$((${NGPUS}*${WORLD_SIZE})) -SEQ_LEN=${SEQ_LEN:-32768} -BATCH_SIZE=${BATCH_SIZE:-1} -ACC=${ACC:-4} -SAVE_PATH=${SAVE_PATH:-"./output/70B_32K_bs_1M_rope_1M_step_1000_lr_2e-5"} -SAVE_STEPS=${SAVE_STEPS:-500} -export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' -export NCCL_P2P_DISABLE=1 -export NCCL_IB_GID_INDEX=3 -export WANDB_DISABLED=true -export NCCL_DEBUG=WARN -export CUDA_LAUNCH_BLOCKING=1 -echo ${RANK}/$[WORLD_SIZE] -if [ ${MASTER_ADDR} == 'localhost' ]; then - export MASTER_ADDR=`hostname -i` -fi -echo ${MASTER_ADDR}:${MASTER_PORT} -export TRANSFORMERS_VERBOSITY='debug' - -accelerate launch \ ---config_file examples/accelerate/ds_multi_nodes.yaml \ ---use_deepspeed \ ---num_machines ${WORLD_SIZE} \ ---num_processes ${NUM_PROCESSES} \ ---main_process_ip ${MASTER_ADDR} \ ---main_process_port ${MASTER_PORT} \ ---machine_rank ${RANK} \ ---rdzv_conf "rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT,rdzv_backend=c10d" \ -src/train.py \ ---model_name_or_path ${MODEL_DIR} \ ---stage pt \ ---do_train \ ---finetuning_type full \ ---parallel_mode dist_flash_attn \ ---deepspeed examples/deepspeed/ds_z3_offload_config.json \ ---tokenized_path /mnt/zj-gpfs/home/lsy/data/tokenized_c4_demo \ ---template llama3 \ ---cutoff_len ${SEQ_LEN} \ ---max_samples 1000 \ ---overwrite_cache \ ---preprocessing_num_workers 16 \ ---output_dir ${SAVE_PATH} \ ---logging_steps 1 \ ---save_steps ${SAVE_STEPS} \ ---plot_loss \ ---overwrite_output_dir \ ---per_device_train_batch_size ${BATCH_SIZE} \ ---gradient_accumulation_steps ${ACC} \ ---learning_rate 2e-5 \ ---num_train_epochs 3.0 \ ---lr_scheduler_type cosine \ ---warmup_ratio 0.1 \ ---bf16 \ ---ddp_timeout 180000000 \ ---val_size 0.1 \ ---eval_strategy steps \ ---eval_steps 9999 \ ---dataloader_drop_last - -# In the saved files, there are model-00001-of-00003.safetensors to model-00001-of-00003.safetensors. Somehow model.safetensors is unnecessary and should be removed. 
-# rm output/7B_32K_bs_1M_rope_1M_step_1000_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 2 \ -# --output-dir ./output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5 \ -# --seed 2022 \ -# --wandb EasyContext \ -# --max-train-steps 1000 \ -# --learning-rate 2e-5 \ -# --dataset yaofu/slimpajama-per-source-length-upsample \ -# --model output/7B_32K_bs_1M_rope_1M_step_1000_lr_2e-5 \ -# --seq-length 65536 \ -# --rope-theta 5000000 \ -# --parallel_mode data_parallel - -# rm output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 4 \ -# --output-dir ./output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5 \ -# --seed 2023 \ -# --wandb EasyContext \ -# --max-train-steps 500 \ -# --learning-rate 2e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ -# --model output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5 \ -# --seq-length 256000 \ -# --rope-theta 10000000 \ -# --parallel_mode zigzag_ring_attn - -# rm output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 4 \ -# --output-dir ./output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5 \ -# --seed 2024 \ -# --wandb EasyContext \ -# --max-train-steps 500 \ -# --learning-rate 2e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ -# --model output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5 \ -# --seq-length 256000 \ -# --rope-theta 25000000 \ -# --parallel_mode zigzag_ring_attn - -# rm output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 4 \ -# --output-dir ./output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5 \ -# --seed 2025 \ -# --wandb EasyContext \ -# --max-train-steps 150 \ -# --learning-rate 2e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ -# --model output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5 \ -# --seq-length 256000 \ -# --rope-theta 50000000 \ -# --parallel_mode zigzag_ring_attn - -# rm output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 2 \ -# --output-dir ./output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5 \ -# --seed 2026 \ -# --wandb EasyContext \ -# --max-train-steps 300 \ -# --learning-rate 2e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_1M \ -# --model output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5 \ -# --seq-length 512000 \ -# --rope-theta 100000000 \ -# --parallel_mode zigzag_ring_attn - - -# rm output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 2 \ -# --output-dir ./output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5 \ -# --seed 2027 \ -# --wandb EasyContext \ -# --max-train-steps 90 \ -# --learning-rate 1e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_1M \ -# --model output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5 \ -# --seq-length 
512000 \ -# --rope-theta 250000000 \ -# --parallel_mode zigzag_ring_attn - -# rm output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5/model.safetensors - - -### Finally we directly set the rope_theta in output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5/config.json to 1,000,000,000 diff --git a/Llama3-70B-sft.sh b/Llama3-70B-sft.sh deleted file mode 100644 index 045b492c0c..0000000000 --- a/Llama3-70B-sft.sh +++ /dev/null @@ -1,176 +0,0 @@ -# You can observe that the number of steps for different stage is quite different. They are not magic number. They are set to those numbers simply because I esitimate the time it takes to finish the training, and -# choose the number such that it fits my daily schedule>_<. This is for you to exactly reproduce my results. You many change the steps to other numbers if you want to. -MODEL_DIR=${MODEL_DIR:-"/mnt/zj-gpfs/home/lsy/models/Meta-Llama-3-8B"} -NGPUS=${NGPUS:-8} -WORLD_SIZE=${WORLD_SIZE:-1} -NUM_PROCESSES=$((${NGPUS}*${WORLD_SIZE})) -SEQ_LEN=${SEQ_LEN:-32768} -BATCH_SIZE=${BATCH_SIZE:-1} -ACC=${ACC:-6} -SAVE_PATH=${SAVE_PATH:-"./output/70B_32K_bs_1M_rope_1M_step_1000_lr_2e-5"} -export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' -export NCCL_P2P_DISABLE=1 -export NCCL_IB_GID_INDEX=3 -export WANDB_DISABLED=true -export NCCL_DEBUG=WARN -export CUDA_LAUNCH_BLOCKING=1 -echo ${RANK}/$[WORLD_SIZE] -if [ ${MASTER_ADDR} == 'localhost' ]; then - export MASTER_ADDR=`hostname -i` -fi -echo ${MASTER_ADDR}:${MASTER_PORT} - -accelerate launch \ ---config_file examples/accelerate/ds_multi_nodes.yaml \ ---use_deepspeed \ ---num_machines ${WORLD_SIZE} \ ---num_processes ${NUM_PROCESSES} \ ---main_process_ip ${MASTER_ADDR} \ ---main_process_port ${MASTER_PORT} \ ---machine_rank ${RANK} \ ---rdzv_conf "rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT,rdzv_backend=c10d" \ -src/train.py \ ---model_name_or_path ${MODEL_DIR} \ ---stage sft \ ---do_train \ ---finetuning_type full \ ---parallel_mode dist_flash_attn \ ---deepspeed examples/deepspeed/ds_z3_offload_config.json \ ---dataset long_sft_32k \ ---template llama3 \ ---cutoff_len ${SEQ_LEN} \ ---max_samples 20000 \ ---overwrite_cache \ ---preprocessing_num_workers 16 \ ---output_dir ${SAVE_PATH} \ ---logging_steps 1 \ ---save_steps 130 \ ---plot_loss \ ---overwrite_output_dir \ ---per_device_train_batch_size ${BATCH_SIZE} \ ---gradient_accumulation_steps ${ACC} \ ---learning_rate 2e-5 \ ---num_train_epochs 3.0 \ ---lr_scheduler_type cosine \ ---warmup_ratio 0.1 \ ---bf16 \ ---ddp_timeout 180000000 \ ---val_size 0.1 \ ---eval_strategy steps \ ---eval_steps 50000 \ ---dataloader_drop_last - -# In the saved files, there are model-00001-of-00003.safetensors to model-00001-of-00003.safetensors. Somehow model.safetensors is unnecessary and should be removed. 
-# rm output/7B_32K_bs_1M_rope_1M_step_1000_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 2 \ -# --output-dir ./output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5 \ -# --seed 2022 \ -# --wandb EasyContext \ -# --max-train-steps 1000 \ -# --learning-rate 2e-5 \ -# --dataset yaofu/slimpajama-per-source-length-upsample \ -# --model output/7B_32K_bs_1M_rope_1M_step_1000_lr_2e-5 \ -# --seq-length 65536 \ -# --rope-theta 5000000 \ -# --parallel_mode data_parallel - -# rm output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 4 \ -# --output-dir ./output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5 \ -# --seed 2023 \ -# --wandb EasyContext \ -# --max-train-steps 500 \ -# --learning-rate 2e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ -# --model output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5 \ -# --seq-length 256000 \ -# --rope-theta 10000000 \ -# --parallel_mode zigzag_ring_attn - -# rm output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 4 \ -# --output-dir ./output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5 \ -# --seed 2024 \ -# --wandb EasyContext \ -# --max-train-steps 500 \ -# --learning-rate 2e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ -# --model output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5 \ -# --seq-length 256000 \ -# --rope-theta 25000000 \ -# --parallel_mode zigzag_ring_attn - -# rm output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 4 \ -# --output-dir ./output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5 \ -# --seed 2025 \ -# --wandb EasyContext \ -# --max-train-steps 150 \ -# --learning-rate 2e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ -# --model output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5 \ -# --seq-length 256000 \ -# --rope-theta 50000000 \ -# --parallel_mode zigzag_ring_attn - -# rm output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 2 \ -# --output-dir ./output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5 \ -# --seed 2026 \ -# --wandb EasyContext \ -# --max-train-steps 300 \ -# --learning-rate 2e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_1M \ -# --model output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5 \ -# --seq-length 512000 \ -# --rope-theta 100000000 \ -# --parallel_mode zigzag_ring_attn - - -# rm output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 2 \ -# --output-dir ./output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5 \ -# --seed 2027 \ -# --wandb EasyContext \ -# --max-train-steps 90 \ -# --learning-rate 1e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_1M \ -# --model output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5 \ -# --seq-length 
512000 \ -# --rope-theta 250000000 \ -# --parallel_mode zigzag_ring_attn - -# rm output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5/model.safetensors - - -### Finally we directly set the rope_theta in output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5/config.json to 1,000,000,000 diff --git a/Llama3-70B.sh b/Llama3-70B.sh deleted file mode 100644 index b139933734..0000000000 --- a/Llama3-70B.sh +++ /dev/null @@ -1,172 +0,0 @@ -# You can observe that the number of steps for different stage is quite different. They are not magic number. They are set to those numbers simply because I esitimate the time it takes to finish the training, and -# choose the number such that it fits my daily schedule>_<. This is for you to exactly reproduce my results. You many change the steps to other numbers if you want to. -MODEL_DIR=${MODEL_DIR:-"/root/model/Meta-Llama-3-70B"} -NGPUS=${NGPUS:-8} -WORLD_SIZE=${WORLD_SIZE:-1} -NUM_PROCESSES=$[${NGPUS}*$[WORLD_SIZE]] -SEQ_LEN=${SEQ_LEN:-32768} -SP_SIZE=${SP_SIZE:-1} -BATCH_SIZE=${BATCH_SIZE:-1} -export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' -export WANDB_DISABLED=true -export NCCL_DEBUG=WARN -echo ${RANK}/$[WORLD_SIZE] -if [ ${MASTER_ADDR} == 'localhost' ]; then - export MASTER_ADDR=`hostname -i` -fi -echo ${MASTER_ADDR}:${MASTER_PORT} - -accelerate launch \ ---config_file examples/accelerate/ds_multi_nodes.yaml \ ---use_deepspeed \ ---num_machines ${WORLD_SIZE} \ ---num_processes ${NUM_PROCESSES} \ ---main_process_ip ${MASTER_ADDR} \ ---main_process_port ${MASTER_PORT} \ ---machine_rank ${RANK} \ ---rdzv_conf "rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT,rdzv_backend=c10d" \ -src/train.py \ ---model_name_or_path ${MODEL_DIR} \ ---stage sft \ ---do_train \ ---finetuning_type full \ ---parallel_mode dist_flash_attn \ ---sp_size ${SP_SIZE} \ ---deepspeed examples/deepspeed/ds_z3_offload_config.json \ ---dataset long_sft_128k \ ---template llama3 \ ---cutoff_len ${SEQ_LEN} \ ---max_steps 10 \ ---overwrite_cache \ ---preprocessing_num_workers 16 \ ---output_dir ./output/70B_32K_bs_1M_rope_1M_step_1000_lr_2e-5 \ ---logging_steps 1 \ ---save_steps 500 \ ---plot_loss \ ---overwrite_output_dir \ ---per_device_train_batch_size ${BATCH_SIZE} \ ---gradient_accumulation_steps 4 \ ---learning_rate 2e-5 \ ---num_train_epochs 1.0 \ ---lr_scheduler_type cosine \ ---warmup_ratio 0.1 \ ---bf16 \ ---ddp_timeout 180000000 \ ---val_size 0.1 \ ---eval_strategy steps \ ---eval_steps 1000 - -# In the saved files, there are model-00001-of-00003.safetensors to model-00001-of-00003.safetensors. Somehow model.safetensors is unnecessary and should be removed. 
-# rm output/7B_32K_bs_1M_rope_1M_step_1000_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 2 \ -# --output-dir ./output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5 \ -# --seed 2022 \ -# --wandb EasyContext \ -# --max-train-steps 1000 \ -# --learning-rate 2e-5 \ -# --dataset yaofu/slimpajama-per-source-length-upsample \ -# --model output/7B_32K_bs_1M_rope_1M_step_1000_lr_2e-5 \ -# --seq-length 65536 \ -# --rope-theta 5000000 \ -# --parallel_mode data_parallel - -# rm output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 4 \ -# --output-dir ./output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5 \ -# --seed 2023 \ -# --wandb EasyContext \ -# --max-train-steps 500 \ -# --learning-rate 2e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ -# --model output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5 \ -# --seq-length 256000 \ -# --rope-theta 10000000 \ -# --parallel_mode zigzag_ring_attn - -# rm output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 4 \ -# --output-dir ./output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5 \ -# --seed 2024 \ -# --wandb EasyContext \ -# --max-train-steps 500 \ -# --learning-rate 2e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ -# --model output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5 \ -# --seq-length 256000 \ -# --rope-theta 25000000 \ -# --parallel_mode zigzag_ring_attn - -# rm output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 4 \ -# --output-dir ./output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5 \ -# --seed 2025 \ -# --wandb EasyContext \ -# --max-train-steps 150 \ -# --learning-rate 2e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ -# --model output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5 \ -# --seq-length 256000 \ -# --rope-theta 50000000 \ -# --parallel_mode zigzag_ring_attn - -# rm output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 2 \ -# --output-dir ./output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5 \ -# --seed 2026 \ -# --wandb EasyContext \ -# --max-train-steps 300 \ -# --learning-rate 2e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_1M \ -# --model output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5 \ -# --seq-length 512000 \ -# --rope-theta 100000000 \ -# --parallel_mode zigzag_ring_attn - - -# rm output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 2 \ -# --output-dir ./output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5 \ -# --seed 2027 \ -# --wandb EasyContext \ -# --max-train-steps 90 \ -# --learning-rate 1e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_1M \ -# --model output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5 \ -# --seq-length 
512000 \ -# --rope-theta 250000000 \ -# --parallel_mode zigzag_ring_attn - -# rm output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5/model.safetensors - - -### Finally we directly set the rope_theta in output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5/config.json to 1,000,000,000 diff --git a/Llama3-8B.sh b/Llama3-8B.sh deleted file mode 100644 index 0e67d291e9..0000000000 --- a/Llama3-8B.sh +++ /dev/null @@ -1,174 +0,0 @@ -# You can observe that the number of steps for different stage is quite different. They are not magic number. They are set to those numbers simply because I esitimate the time it takes to finish the training, and -# choose the number such that it fits my daily schedule>_<. This is for you to exactly reproduce my results. You many change the steps to other numbers if you want to. -MODEL_DIR=${MODEL_DIR:-"/mnt/zj-gpfs/home/qianhao/models/Meta-Llama-3-8B"} -NGPUS=${NGPUS:-8} -WORLD_SIZE=${WORLD_SIZE:-1} -NUM_PROCESSES=$[${NGPUS}*$[WORLD_SIZE]] -SEQ_LEN=${SEQ_LEN:-1024} -SP_SIZE=${SP_SIZE:-1} -BATCH_SIZE=${BATCH_SIZE:-1} -export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' -export WANDB_DISABLED=true -echo ${RANK}/$[WORLD_SIZE] -if [ ${MASTER_ADDR} == 'localhost' ]; then - export MASTER_ADDR=`hostname -i` -fi -echo ${MASTER_ADDR}:${MASTER_PORT} - -accelerate launch \ ---config_file examples/accelerate/ds_multi_nodes.yaml \ ---use_deepspeed \ ---num_machines ${WORLD_SIZE} \ ---num_processes ${NUM_PROCESSES} \ ---main_process_ip ${MASTER_ADDR} \ ---main_process_port ${MASTER_PORT} \ ---machine_rank ${RANK} \ ---rdzv_conf "rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT,rdzv_backend=c10d" \ -src/train.py \ ---model_name_or_path ${MODEL_DIR} \ ---stage sft \ ---do_train \ ---finetuning_type full \ ---lora_target all \ ---parallel_mode dist_flash_attn \ ---sp_size ${SP_SIZE} \ ---deepspeed examples/deepspeed/ds_z3_offload_config.json \ ---dataset alpaca_en \ ---template llama3 \ ---cutoff_len ${SEQ_LEN} \ ---max_samples 1200 \ ---overwrite_cache \ ---preprocessing_num_workers 16 \ ---output_dir ./output/8B_1K_bs_1_step_1000_lr_2e-5 \ ---logging_steps 1 \ ---save_steps 500 \ ---plot_loss \ ---overwrite_output_dir \ ---per_device_train_batch_size ${BATCH_SIZE} \ ---per_device_eval_batch_size ${BATCH_SIZE} \ ---gradient_accumulation_steps 4 \ ---learning_rate 2e-5 \ ---num_train_epochs 3.0 \ ---lr_scheduler_type cosine \ ---warmup_ratio 0.1 \ ---bf16 \ ---ddp_timeout 180000000 \ ---val_size 0.1 \ ---eval_strategy steps \ ---dataloader_drop_last \ ---eval_steps 1001 - -# In the saved files, there are model-00001-of-00003.safetensors to model-00001-of-00003.safetensors. Somehow model.safetensors is unnecessary and should be removed. 
-# rm output/7B_32K_bs_1M_rope_1M_step_1000_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 2 \ -# --output-dir ./output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5 \ -# --seed 2022 \ -# --wandb EasyContext \ -# --max-train-steps 1000 \ -# --learning-rate 2e-5 \ -# --dataset yaofu/slimpajama-per-source-length-upsample \ -# --model output/7B_32K_bs_1M_rope_1M_step_1000_lr_2e-5 \ -# --seq-length 65536 \ -# --rope-theta 5000000 \ -# --parallel_mode data_parallel - -# rm output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 4 \ -# --output-dir ./output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5 \ -# --seed 2023 \ -# --wandb EasyContext \ -# --max-train-steps 500 \ -# --learning-rate 2e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ -# --model output/7B_64K_bs_1M_rope_5M_step_1000_lr_2e-5 \ -# --seq-length 256000 \ -# --rope-theta 10000000 \ -# --parallel_mode zigzag_ring_attn - -# rm output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 4 \ -# --output-dir ./output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5 \ -# --seed 2024 \ -# --wandb EasyContext \ -# --max-train-steps 500 \ -# --learning-rate 2e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ -# --model output/7B_0.256M_bs_1M_rope_10M_step_500_lr_2e-5 \ -# --seq-length 256000 \ -# --rope-theta 25000000 \ -# --parallel_mode zigzag_ring_attn - -# rm output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 4 \ -# --output-dir ./output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5 \ -# --seed 2025 \ -# --wandb EasyContext \ -# --max-train-steps 150 \ -# --learning-rate 2e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_256K \ -# --model output/7B_0.256M_bs_1M_rope_25M_step_500_lr_2e-5 \ -# --seq-length 256000 \ -# --rope-theta 50000000 \ -# --parallel_mode zigzag_ring_attn - -# rm output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 2 \ -# --output-dir ./output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5 \ -# --seed 2026 \ -# --wandb EasyContext \ -# --max-train-steps 300 \ -# --learning-rate 2e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_1M \ -# --model output/7B_0.256M_bs_1M_rope_50M_step_150_lr_2e-5 \ -# --seq-length 512000 \ -# --rope-theta 100000000 \ -# --parallel_mode zigzag_ring_attn - - -# rm output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5/model.safetensors - -# accelerate launch \ -# --config_file accelerate_configs/single_node.yaml \ -# train.py \ -# --batch-size 1 \ -# --gradient-accumulate-every 2 \ -# --output-dir ./output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5 \ -# --seed 2027 \ -# --wandb EasyContext \ -# --max-train-steps 90 \ -# --learning-rate 1e-5 \ -# --dataset PY007/slimpajama_llama_tokenized_upsample_4096_chunk_1M \ -# --model output/7B_0.5M_bs_1M_rope_100M_step_300_lr_2e-5 \ -# --seq-length 
512000 \ -# --rope-theta 250000000 \ -# --parallel_mode zigzag_ring_attn - -# rm output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5/model.safetensors - - -### Finally we directly set the rope_theta in output/7B_0.5M_bs_1M_rope_250M_step_90_lr_2e-5/config.json to 1,000,000,000 diff --git a/tianqing_examples/Llama3-70B-pt-dp.sh b/tianqing_examples/Llama3-70B-pt-dp.sh new file mode 100644 index 0000000000..25e607a8e7 --- /dev/null +++ b/tianqing_examples/Llama3-70B-pt-dp.sh @@ -0,0 +1,59 @@ +# You can observe that the number of steps for different stage is quite different. They are not magic number. They are set to those numbers simply because I esitimate the time it takes to finish the training, and +# choose the number such that it fits my daily schedule>_<. This is for you to exactly reproduce my results. You many change the steps to other numbers if you want to. +MODEL_DIR=${MODEL_DIR:-"/mnt/zj-gpfs/home/lsy/models/Meta-Llama-3-8B"} +NGPUS=${NGPUS:-8} +WORLD_SIZE=${WORLD_SIZE:-1} +NUM_PROCESSES=$((${NGPUS}*${WORLD_SIZE})) +SEQ_LEN=${SEQ_LEN:-32768} +BATCH_SIZE=${BATCH_SIZE:-1} +ACC=${ACC:-4} +SAVE_PATH=${SAVE_PATH:-"./output/70B_32K_bs_1M_rope_1M_step_1000_lr_2e-5"} +SAVE_STEPS=${SAVE_STEPS:-500} +export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' +export WANDB_DISABLED=true +export NCCL_DEBUG=WARN +echo ${RANK}/$[WORLD_SIZE] +if [ ${MASTER_ADDR} == 'localhost' ]; then + export MASTER_ADDR=`hostname -i` +fi +echo ${MASTER_ADDR}:${MASTER_PORT} + +accelerate launch \ +--config_file examples/accelerate/ds_multi_nodes.yaml \ +--use_deepspeed \ +--num_machines ${WORLD_SIZE} \ +--num_processes ${NUM_PROCESSES} \ +--main_process_ip ${MASTER_ADDR} \ +--main_process_port ${MASTER_PORT} \ +--machine_rank ${RANK} \ +--rdzv_conf "rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT,rdzv_backend=c10d" \ +src/train.py \ +--model_name_or_path ${MODEL_DIR} \ +--stage pt \ +--do_train \ +--finetuning_type full \ +--parallel_mode data_parallel \ +--deepspeed examples/deepspeed/ds_z3_offload_config.json \ +--tokenized_path /mnt/zj-gpfs/home/lsy/data/tokenized_c4_demo \ +--template llama3 \ +--cutoff_len ${SEQ_LEN} \ +--max_samples 1000 \ +--overwrite_cache \ +--preprocessing_num_workers 16 \ +--output_dir ${SAVE_PATH} \ +--logging_steps 1 \ +--save_steps ${SAVE_STEPS} \ +--plot_loss \ +--overwrite_output_dir \ +--per_device_train_batch_size ${BATCH_SIZE} \ +--gradient_accumulation_steps ${ACC} \ +--learning_rate 2e-5 \ +--num_train_epochs 3.0 \ +--lr_scheduler_type cosine \ +--warmup_ratio 0.1 \ +--bf16 \ +--ddp_timeout 180000000 \ +--val_size 0.1 \ +--eval_strategy steps \ +--eval_steps 9999 \ +--dataloader_drop_last diff --git a/tianqing_examples/Llama3-70B-pt-sp-lora.sh b/tianqing_examples/Llama3-70B-pt-sp-lora.sh new file mode 100644 index 0000000000..24f0cdf169 --- /dev/null +++ b/tianqing_examples/Llama3-70B-pt-sp-lora.sh @@ -0,0 +1,59 @@ +# You can observe that the number of steps for different stage is quite different. They are not magic number. They are set to those numbers simply because I esitimate the time it takes to finish the training, and +# choose the number such that it fits my daily schedule>_<. This is for you to exactly reproduce my results. You many change the steps to other numbers if you want to. 
+MODEL_DIR=${MODEL_DIR:-"/mnt/zj-gpfs/home/lsy/models/Meta-Llama-3-8B"} +NGPUS=${NGPUS:-8} +WORLD_SIZE=${WORLD_SIZE:-1} +NUM_PROCESSES=$((${NGPUS}*${WORLD_SIZE})) +SEQ_LEN=${SEQ_LEN:-32768} +BATCH_SIZE=${BATCH_SIZE:-1} +ACC=${ACC:-4} +SAVE_PATH=${SAVE_PATH:-"./output/70B_32K_bs_1M_rope_1M_step_1000_lr_2e-5"} +SAVE_STEPS=${SAVE_STEPS:-500} +export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' +export WANDB_DISABLED=true +export NCCL_DEBUG=WARN +echo ${RANK}/$[WORLD_SIZE] +if [ ${MASTER_ADDR} == 'localhost' ]; then + export MASTER_ADDR=`hostname -i` +fi +echo ${MASTER_ADDR}:${MASTER_PORT} + +accelerate launch \ +--config_file examples/accelerate/ds_multi_nodes.yaml \ +--use_deepspeed \ +--num_machines ${WORLD_SIZE} \ +--num_processes ${NUM_PROCESSES} \ +--main_process_ip ${MASTER_ADDR} \ +--main_process_port ${MASTER_PORT} \ +--machine_rank ${RANK} \ +--rdzv_conf "rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT,rdzv_backend=c10d" \ +src/train.py \ +--model_name_or_path ${MODEL_DIR} \ +--stage pt \ +--do_train \ +--finetuning_type lora \ +--parallel_mode dist_flash_attn \ +--deepspeed examples/deepspeed/ds_z3_offload_config.json \ +--tokenized_path ${DATA_PATH:-"/mnt/zj-gpfs/home/lsy/data/per_source_upsample_32769_common_5b"} \ +--template llama3 \ +--cutoff_len ${SEQ_LEN} \ +--max_samples 1000000 \ +--overwrite_cache \ +--preprocessing_num_workers 16 \ +--output_dir ${SAVE_PATH} \ +--logging_steps 1 \ +--save_steps ${SAVE_STEPS} \ +--plot_loss \ +--overwrite_output_dir \ +--per_device_train_batch_size ${BATCH_SIZE} \ +--gradient_accumulation_steps ${ACC} \ +--learning_rate 2e-5 \ +--num_train_epochs 3.0 \ +--lr_scheduler_type cosine \ +--warmup_ratio 0.1 \ +--bf16 \ +--ddp_timeout 180000000 \ +--val_size 0.1 \ +--eval_strategy steps \ +--eval_steps 99999999 \ +--dataloader_drop_last diff --git a/tianqing_examples/Llama3-70B-pt-sp.sh b/tianqing_examples/Llama3-70B-pt-sp.sh new file mode 100644 index 0000000000..955079a889 --- /dev/null +++ b/tianqing_examples/Llama3-70B-pt-sp.sh @@ -0,0 +1,59 @@ +# You can observe that the number of steps for different stage is quite different. They are not magic number. They are set to those numbers simply because I esitimate the time it takes to finish the training, and +# choose the number such that it fits my daily schedule>_<. This is for you to exactly reproduce my results. You many change the steps to other numbers if you want to. 
+MODEL_DIR=${MODEL_DIR:-"/mnt/zj-gpfs/home/lsy/models/Meta-Llama-3-8B"} +NGPUS=${NGPUS:-8} +WORLD_SIZE=${WORLD_SIZE:-1} +NUM_PROCESSES=$((${NGPUS}*${WORLD_SIZE})) +SEQ_LEN=${SEQ_LEN:-32768} +BATCH_SIZE=${BATCH_SIZE:-1} +ACC=${ACC:-4} +SAVE_PATH=${SAVE_PATH:-"./output/70B_32K_bs_1M_rope_1M_step_1000_lr_2e-5"} +SAVE_STEPS=${SAVE_STEPS:-500} +export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' +export WANDB_DISABLED=true +export NCCL_DEBUG=WARN +echo ${RANK}/$[WORLD_SIZE] +if [ ${MASTER_ADDR} == 'localhost' ]; then + export MASTER_ADDR=`hostname -i` +fi +echo ${MASTER_ADDR}:${MASTER_PORT} + +accelerate launch \ +--config_file examples/accelerate/ds_multi_nodes.yaml \ +--use_deepspeed \ +--num_machines ${WORLD_SIZE} \ +--num_processes ${NUM_PROCESSES} \ +--main_process_ip ${MASTER_ADDR} \ +--main_process_port ${MASTER_PORT} \ +--machine_rank ${RANK} \ +--rdzv_conf "rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT,rdzv_backend=c10d" \ +src/train.py \ +--model_name_or_path ${MODEL_DIR} \ +--stage pt \ +--do_train \ +--finetuning_type full \ +--parallel_mode dist_flash_attn \ +--deepspeed examples/deepspeed/ds_z3_offload_config.json \ +--tokenized_path /mnt/zj-gpfs/home/lsy/data/tokenized_c4_demo \ +--template llama3 \ +--cutoff_len ${SEQ_LEN} \ +--max_samples 1000 \ +--overwrite_cache \ +--preprocessing_num_workers 16 \ +--output_dir ${SAVE_PATH} \ +--logging_steps 1 \ +--save_steps ${SAVE_STEPS} \ +--plot_loss \ +--overwrite_output_dir \ +--per_device_train_batch_size ${BATCH_SIZE} \ +--gradient_accumulation_steps ${ACC} \ +--learning_rate 2e-5 \ +--num_train_epochs 3.0 \ +--lr_scheduler_type cosine \ +--warmup_ratio 0.1 \ +--bf16 \ +--ddp_timeout 180000000 \ +--val_size 0.1 \ +--eval_strategy steps \ +--eval_steps 9999 \ +--dataloader_drop_last diff --git a/tianqing_examples/Llama3-70B.sh b/tianqing_examples/Llama3-70B.sh new file mode 100644 index 0000000000..df290cfccc --- /dev/null +++ b/tianqing_examples/Llama3-70B.sh @@ -0,0 +1,57 @@ +# You can observe that the number of steps for different stage is quite different. They are not magic number. They are set to those numbers simply because I esitimate the time it takes to finish the training, and +# choose the number such that it fits my daily schedule>_<. This is for you to exactly reproduce my results. You many change the steps to other numbers if you want to. 
+MODEL_DIR=${MODEL_DIR:-"/root/model/Meta-Llama-3-70B"} +NGPUS=${NGPUS:-8} +WORLD_SIZE=${WORLD_SIZE:-1} +NUM_PROCESSES=$[${NGPUS}*$[WORLD_SIZE]] +SEQ_LEN=${SEQ_LEN:-32768} +SP_SIZE=${SP_SIZE:-1} +BATCH_SIZE=${BATCH_SIZE:-1} +export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' +export WANDB_DISABLED=true +export NCCL_DEBUG=WARN +echo ${RANK}/$[WORLD_SIZE] +if [ ${MASTER_ADDR} == 'localhost' ]; then + export MASTER_ADDR=`hostname -i` +fi +echo ${MASTER_ADDR}:${MASTER_PORT} + +accelerate launch \ +--config_file examples/accelerate/ds_multi_nodes.yaml \ +--use_deepspeed \ +--num_machines ${WORLD_SIZE} \ +--num_processes ${NUM_PROCESSES} \ +--main_process_ip ${MASTER_ADDR} \ +--main_process_port ${MASTER_PORT} \ +--machine_rank ${RANK} \ +--rdzv_conf "rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT,rdzv_backend=c10d" \ +src/train.py \ +--model_name_or_path ${MODEL_DIR} \ +--stage sft \ +--do_train \ +--finetuning_type full \ +--parallel_mode dist_flash_attn \ +--sp_size ${SP_SIZE} \ +--deepspeed examples/deepspeed/ds_z3_offload_config.json \ +--dataset long_sft_128k \ +--template llama3 \ +--cutoff_len ${SEQ_LEN} \ +--max_steps 10 \ +--overwrite_cache \ +--preprocessing_num_workers 16 \ +--output_dir ./output/70B_32K_bs_1M_rope_1M_step_1000_lr_2e-5 \ +--logging_steps 1 \ +--save_steps 500 \ +--plot_loss \ +--overwrite_output_dir \ +--per_device_train_batch_size ${BATCH_SIZE} \ +--gradient_accumulation_steps 4 \ +--learning_rate 2e-5 \ +--num_train_epochs 1.0 \ +--lr_scheduler_type cosine \ +--warmup_ratio 0.1 \ +--bf16 \ +--ddp_timeout 180000000 \ +--val_size 0.1 \ +--eval_strategy steps \ +--eval_steps 1000 diff --git a/tianqing_examples/Llama3-8B.sh b/tianqing_examples/Llama3-8B.sh new file mode 100644 index 0000000000..b93f5823f8 --- /dev/null +++ b/tianqing_examples/Llama3-8B.sh @@ -0,0 +1,59 @@ +# You can observe that the number of steps for different stage is quite different. They are not magic number. They are set to those numbers simply because I esitimate the time it takes to finish the training, and +# choose the number such that it fits my daily schedule>_<. This is for you to exactly reproduce my results. You many change the steps to other numbers if you want to. 
+MODEL_DIR=${MODEL_DIR:-"/mnt/zj-gpfs/home/qianhao/models/Meta-Llama-3-8B"} +NGPUS=${NGPUS:-8} +WORLD_SIZE=${WORLD_SIZE:-1} +NUM_PROCESSES=$[${NGPUS}*$[WORLD_SIZE]] +SEQ_LEN=${SEQ_LEN:-1024} +SP_SIZE=${SP_SIZE:-1} +BATCH_SIZE=${BATCH_SIZE:-1} +export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' +export WANDB_DISABLED=true +echo ${RANK}/$[WORLD_SIZE] +if [ ${MASTER_ADDR} == 'localhost' ]; then + export MASTER_ADDR=`hostname -i` +fi +echo ${MASTER_ADDR}:${MASTER_PORT} + +accelerate launch \ +--config_file examples/accelerate/ds_multi_nodes.yaml \ +--use_deepspeed \ +--num_machines ${WORLD_SIZE} \ +--num_processes ${NUM_PROCESSES} \ +--main_process_ip ${MASTER_ADDR} \ +--main_process_port ${MASTER_PORT} \ +--machine_rank ${RANK} \ +--rdzv_conf "rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT,rdzv_backend=c10d" \ +src/train.py \ +--model_name_or_path ${MODEL_DIR} \ +--stage sft \ +--do_train \ +--finetuning_type full \ +--lora_target all \ +--parallel_mode dist_flash_attn \ +--sp_size ${SP_SIZE} \ +--deepspeed examples/deepspeed/ds_z3_offload_config.json \ +--dataset alpaca_en \ +--template llama3 \ +--cutoff_len ${SEQ_LEN} \ +--max_samples 1200 \ +--overwrite_cache \ +--preprocessing_num_workers 16 \ +--output_dir ./output/8B_1K_bs_1_step_1000_lr_2e-5 \ +--logging_steps 1 \ +--save_steps 500 \ +--plot_loss \ +--overwrite_output_dir \ +--per_device_train_batch_size ${BATCH_SIZE} \ +--per_device_eval_batch_size ${BATCH_SIZE} \ +--gradient_accumulation_steps 4 \ +--learning_rate 2e-5 \ +--num_train_epochs 3.0 \ +--lr_scheduler_type cosine \ +--warmup_ratio 0.1 \ +--bf16 \ +--ddp_timeout 180000000 \ +--val_size 0.1 \ +--eval_strategy steps \ +--dataloader_drop_last \ +--eval_steps 1001 diff --git a/llama3_full_sft_ds3.yaml b/tianqing_examples/llama3_full_sft_ds3.yaml similarity index 100% rename from llama3_full_sft_ds3.yaml rename to tianqing_examples/llama3_full_sft_ds3.yaml From f4f9659c86e9998c3df4ba4605d5667b350a1462 Mon Sep 17 00:00:00 2001 From: qianhao0713 <475483052@qq.com> Date: Fri, 2 Aug 2024 11:35:48 +0800 Subject: [PATCH 19/31] free buffer --- .../dist_flash_attn/monkey_patch.py | 70 +++++++++++++++++-- src/llamafactory/train/sft/trainer.py | 5 -- tianqing_examples/Llama3-70B.sh | 2 +- 3 files changed, 67 insertions(+), 10 deletions(-) diff --git a/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py b/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py index 317b8b2748..372469f1d0 100644 --- a/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py +++ b/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py @@ -1,20 +1,23 @@ """ Materialization-aware gradient checkpointing monkey patch. 
""" -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import torch from torch import nn from torch.utils.checkpoint import _get_autocast_kwargs, check_backward_validity, get_device_states, set_device_states, detach_variable import transformers -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, BaseModelOutputWithPast +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, BaseModelOutputWithPast, CausalLMOutputWithPast from einops import rearrange from .lightseq_async_attn import _lightseq_forward, _lightseq_backward from .async_communication import initialize_distributed, reset_global_memory_buffer +import deepspeed as ds +from transformers.cache_utils import Cache + # define a global buffer to save flash attention outputs # it's called global because it saves the outputs for all layers global_flash_attn_out_buffer = None @@ -40,7 +43,7 @@ def clean_hook(): def clear_all_buffers_at_the_end_of_training(): # call it at the end of training - global lobal_flash_attn_out_buffer + global global_flash_attn_out_buffer global_flash_attn_out_buffer = None global local_res_grad_buffer local_res_grad_buffer = None @@ -129,6 +132,7 @@ def forward(ctx, run_function, layer_idx, preserve_rng_state, *args): # save flash attention output to global buffer save_flash_attn_out_to_global_buffer(ctx.layer_idx, out) + ds.runtime.utils.see_memory_usage(f"forward, layer={ctx.layer_idx}", force=True) tensor_inputs += [softmax_lse] ctx.softmax_scale = softmax_scale @@ -202,7 +206,8 @@ def backward(ctx, *args): # write flash attention output gradients to buffer if ctx.layer_idx > 0: write_gradient_to_flash_attn_out(ctx.layer_idx-1, detached_inputs[0].grad) - + free_flash_attn_out_buffer(ctx.layer_idx) + ds.runtime.utils.see_memory_usage(f"backward, layer={ctx.layer_idx}", force=True) return (None, None, None) + grads @@ -261,6 +266,7 @@ def forward(ctx, run_function, preserve_rng_state, *args): with torch.no_grad(): outputs = run_function(*args) + ds.runtime.utils.see_memory_usage(f"forward, layer=last", force=True) return outputs @staticmethod @@ -601,8 +607,64 @@ def custom_forward(*inputs): attentions=all_self_attns, ) +def llama_model_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + 
cache_position=cache_position, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + ds.runtime.utils.see_memory_usage(f"forward end", force=True) + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) def apply_dist_flash_attn_monkey_patch_llama(sp_size=None): initialize_distributed(sp_size=sp_size) transformers.models.llama.modeling_llama.LlamaModel.forward = forward transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = llama_layer_forward + transformers.models.llama.modeling_llama.LlamaForCausalLM.forward = llama_model_forward diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index 32301bf3f9..bc4a9c5a17 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -165,11 +165,6 @@ def compute_loss(self, model, inputs, return_outputs=False): else: loss = self.label_smoother(outputs, labels) else: - if isinstance(outputs, dict) and "loss" not in outputs: - raise ValueError( - "The model did not return a loss from the inputs, only the following keys: " - f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}." - ) # We don't use .loss here since the model may return tuples instead of ModelOutput. 
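# Context for the lines removed above: the llama_model_forward patched in earlier in
# this commit always returns loss=None, so the upstream guard raising a ValueError
# when the outputs carry no "loss" entry would always fire under dist_flash_attn.
# It is dropped here; the data_parallel branch below still reads outputs["loss"],
# while the sequence-parallel path presumably assembles the loss inside the trainer.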
if self.finetuning_args.parallel_mode== "data_parallel": loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] diff --git a/tianqing_examples/Llama3-70B.sh b/tianqing_examples/Llama3-70B.sh index df290cfccc..f88008de0e 100644 --- a/tianqing_examples/Llama3-70B.sh +++ b/tianqing_examples/Llama3-70B.sh @@ -33,7 +33,7 @@ src/train.py \ --parallel_mode dist_flash_attn \ --sp_size ${SP_SIZE} \ --deepspeed examples/deepspeed/ds_z3_offload_config.json \ ---dataset long_sft_128k \ +--dataset long_sft_32k \ --template llama3 \ --cutoff_len ${SEQ_LEN} \ --max_steps 10 \ From da7a8cc899653480a55617bf661b7576198866f0 Mon Sep 17 00:00:00 2001 From: qianhao0713 <475483052@qq.com> Date: Fri, 9 Aug 2024 11:31:04 +0800 Subject: [PATCH 20/31] use global_buffer to load/unload activation --- .../dist_flash_attn/monkey_patch.py | 99 ++++++++++++++++--- 1 file changed, 84 insertions(+), 15 deletions(-) diff --git a/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py b/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py index 372469f1d0..5606a23f78 100644 --- a/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py +++ b/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py @@ -14,9 +14,12 @@ from .lightseq_async_attn import _lightseq_forward, _lightseq_backward from .async_communication import initialize_distributed, reset_global_memory_buffer - -import deepspeed as ds from transformers.cache_utils import Cache +import pycuda +import pycuda.driver as drv +import pycuda.autoinit +import numpy as np +import time, os # define a global buffer to save flash attention outputs # it's called global because it saves the outputs for all layers @@ -29,6 +32,56 @@ # hooks for the gradients of residual global_hooks = [] +class Singleton(object): + _instance = None + def __new__(class_, *args, **kwargs): + if not isinstance(class_._instance, class_): + class_._instance = object.__new__(class_, *args, **kwargs) + return class_._instance + +class GlobalBufferManager(Singleton): + + def init(self, num_layers, offload_percent, shape, dtype, device): + torch.cuda.empty_cache() + if hasattr(self, 'initialized'): + return + self.layer_num = num_layers + self.gpu_layer_num = int(num_layers * offload_percent) + self.cpu_layer_num = num_layers - self.gpu_layer_num + self.gpu_buffer = [torch.empty(shape, dtype=dtype, device=device) for _ in range(self.gpu_layer_num)] + self.cpu_buffer = drv.pagelocked_empty([self.cpu_layer_num] + shape, dtype=np.float16) + self.d2h_stream = drv.Stream() + self.h2d_streams = [drv.Stream() for _ in range(self.gpu_layer_num)] + self.initialized = True + + def save_flash_attn_out(self, layer_idx, out): + if layer_idx < 0: + layer_idx = self.layer_num + layer_idx + if layer_idx < self.cpu_layer_num: + drv.memcpy_dtoh_async(self.cpu_buffer[layer_idx], out.data_ptr(), self.d2h_stream) + else: + idx = layer_idx - self.cpu_layer_num + drv.memcpy_dtod(self.gpu_buffer[idx].data_ptr(), out.data_ptr(), out.element_size() * out.nelement()) + + def get_flash_attn_out(self, layer_idx): + if layer_idx < 0: + layer_idx = self.layer_num + layer_idx + if layer_idx > self.cpu_layer_num: + return self.gpu_buffer[layer_idx - self.cpu_layer_num] + idx = self.gpu_layer_num -1 - (self.cpu_layer_num - layer_idx) % self.gpu_layer_num + self.h2d_streams[idx].synchronize() + return self.gpu_buffer[idx] + + def free_flash_attn_out(self, layer_idx): + if layer_idx < 0: + layer_idx = self.layer_num + layer_idx + cpu_layer_idx = layer_idx - self.gpu_layer_num + if cpu_layer_idx < 0: + 
return + idx = self.gpu_layer_num -1 - (self.cpu_layer_num - layer_idx) % self.gpu_layer_num + self.gpu_buffer[idx].grad = None + drv.memcpy_htod_async(self.gpu_buffer[idx].data_ptr(), self.cpu_buffer[cpu_layer_idx], self.h2d_streams[idx]) + def init_flash_attn_buffers(num_layers): # update the global buffer according to number of layers global global_flash_attn_out_buffer @@ -131,13 +184,17 @@ def forward(ctx, run_function, layer_idx, preserve_rng_state, *args): rng_state = None # save flash attention output to global buffer - save_flash_attn_out_to_global_buffer(ctx.layer_idx, out) - ds.runtime.utils.see_memory_usage(f"forward, layer={ctx.layer_idx}", force=True) + # save_flash_attn_out_to_global_buffer(ctx.layer_idx, out) + GlobalBufferManager().save_flash_attn_out(ctx.layer_idx, out) tensor_inputs += [softmax_lse] ctx.softmax_scale = softmax_scale ctx.save_for_backward(*tensor_inputs) - + tensor_inputs_ma = 0 + for ti in tensor_inputs: + tensor_inputs_ma += ti.element_size() * ti.nelement() + if int(os.environ['RANK']) == 0: + print(f"layer: {ctx.layer_idx}, MA: {torch.cuda.memory_allocated()/(1<<30):.2f}, MR: {torch.cuda.memory_reserved()/(1<<30):.2f}, tensor_inputs_ma: {tensor_inputs_ma/(1<<30):2f}") return out, residual @staticmethod @@ -157,7 +214,8 @@ def backward(ctx, *args): # Fill the flash attention output first if ctx.layer_idx > 0: # inputs[0] should be flash attention output - inputs[0] = get_flash_attn_out_from_global_buffer(ctx.layer_idx-1) + # inputs[0] = get_flash_attn_out_from_global_buffer(ctx.layer_idx-1) + inputs[0] = GlobalBufferManager().get_flash_attn_out(ctx.layer_idx-1) for i, idx in enumerate(tensor_indices): inputs[idx] = tensors[i] @@ -188,7 +246,10 @@ def backward(ctx, *args): #dq = torch.empty(q.shape, dtype=q.dtype, device=q.device) #dk = torch.empty(k.shape, dtype=q.dtype, device=q.device) #dv = torch.empty(v.shape, dtype=q.dtype, device=q.device) - out = get_flash_attn_out_from_global_buffer(ctx.layer_idx) + # out = get_flash_attn_out_from_global_buffer(ctx.layer_idx) + out = GlobalBufferManager().get_flash_attn_out(ctx.layer_idx) + if int(os.environ['RANK']) == 0: + print(f"layer: {ctx.layer_idx}, MA: {torch.cuda.memory_allocated()/(1<<30):.2f}, MR: {torch.cuda.memory_reserved()/(1<<30):.2f}") # todo get dout dout = args[0] @@ -205,9 +266,10 @@ def backward(ctx, *args): # write flash attention output gradients to buffer if ctx.layer_idx > 0: - write_gradient_to_flash_attn_out(ctx.layer_idx-1, detached_inputs[0].grad) - free_flash_attn_out_buffer(ctx.layer_idx) - ds.runtime.utils.see_memory_usage(f"backward, layer={ctx.layer_idx}", force=True) + # write_gradient_to_flash_attn_out(ctx.layer_idx-1, detached_inputs[0].grad) + GlobalBufferManager().get_flash_attn_out(ctx.layer_idx-1).grad = detached_inputs[0].grad + # free_flash_attn_out_buffer(ctx.layer_idx) + GlobalBufferManager().free_flash_attn_out(ctx.layer_idx) return (None, None, None) + grads @@ -266,7 +328,6 @@ def forward(ctx, run_function, preserve_rng_state, *args): with torch.no_grad(): outputs = run_function(*args) - ds.runtime.utils.see_memory_usage(f"forward, layer=last", force=True) return outputs @staticmethod @@ -284,7 +345,8 @@ def backward(ctx, *args): # Fill in inputs with appropriate saved tensors. 
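# In the reworked backward of the last checkpointed segment below, the attention
# output is no longer pulled from ctx.saved_tensors: it is fetched back from the
# GlobalBufferManager (index -1), the segment is re-run to rebuild a local graph,
# and the gradient w.r.t. that output is stashed on the buffered tensor's .grad so
# that the preceding layer's backward can pick it up.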
# Fill the flash attention output first # inputs[0] should be flash attention output - inputs[0] = get_flash_attn_out_from_global_buffer(-1) + # inputs[0] = get_flash_attn_out_from_global_buffer(-1) + inputs[0] = GlobalBufferManager().get_flash_attn_out(-1) for i, idx in enumerate(tensor_indices): inputs[idx] = tensors[i] @@ -324,7 +386,8 @@ def backward(ctx, *args): for inp in detached_inputs) # write flash attention output gradients to buffer - write_gradient_to_flash_attn_out(-1, detached_inputs[0].grad) + # write_gradient_to_flash_attn_out(-1, detached_inputs[0].grad) + GlobalBufferManager().get_flash_attn_out(-1).grad = detached_inputs[0].grad return (None, None) + grads @@ -485,7 +548,14 @@ def forward( except: pass # initialize the global buffer - init_flash_attn_buffers(len(self.layers)) + # init_flash_attn_buffers(len(self.layers)) + GlobalBufferManager().init( + self.config.num_hidden_layers, + offload_percent=0.25, + shape=[batch_size, self.config.num_attention_heads, seq_length, self.config.hidden_size // self.config.num_attention_heads], + dtype=hidden_states.dtype, + device=hidden_states.device + ) if use_cache: try: @@ -654,7 +724,6 @@ def llama_model_forward( if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output - ds.runtime.utils.see_memory_usage(f"forward end", force=True) return CausalLMOutputWithPast( loss=loss, logits=logits, From 9c0ce85d1ef291cda79da245b147216b476fe8aa Mon Sep 17 00:00:00 2001 From: qianhao0713 <475483052@qq.com> Date: Wed, 14 Aug 2024 10:15:29 +0800 Subject: [PATCH 21/31] global_buffer add hidden_states and attention_mask --- .../dist_flash_attn/monkey_patch.py | 122 ++++++++++++------ 1 file changed, 86 insertions(+), 36 deletions(-) diff --git a/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py b/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py index 5606a23f78..c12dafc7db 100644 --- a/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py +++ b/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py @@ -48,8 +48,15 @@ def init(self, num_layers, offload_percent, shape, dtype, device): self.layer_num = num_layers self.gpu_layer_num = int(num_layers * offload_percent) self.cpu_layer_num = num_layers - self.gpu_layer_num - self.gpu_buffer = [torch.empty(shape, dtype=dtype, device=device) for _ in range(self.gpu_layer_num)] + self.gpu_buffer = [None for _ in range(self.gpu_layer_num)] self.cpu_buffer = drv.pagelocked_empty([self.cpu_layer_num] + shape, dtype=np.float16) + bs, num_heads, seq_len, emb_size = shape + shape_h = [bs, seq_len, num_heads * emb_size] + shape_a = [bs, seq_len] + self.hidden_state_gpu_buffer = [None for _ in range(self.gpu_layer_num)] + self.hidden_state_cpu_buffer = drv.pagelocked_empty([self.cpu_layer_num] + shape_h, dtype=np.float16) + self.position_id_gpu_buffer = [None for _ in range(self.gpu_layer_num)] + self.position_id_cpu_buffer = drv.pagelocked_empty([self.cpu_layer_num] + shape_a, dtype=np.float16) self.d2h_stream = drv.Stream() self.h2d_streams = [drv.Stream() for _ in range(self.gpu_layer_num)] self.initialized = True @@ -61,26 +68,60 @@ def save_flash_attn_out(self, layer_idx, out): drv.memcpy_dtoh_async(self.cpu_buffer[layer_idx], out.data_ptr(), self.d2h_stream) else: idx = layer_idx - self.cpu_layer_num - drv.memcpy_dtod(self.gpu_buffer[idx].data_ptr(), out.data_ptr(), out.element_size() * out.nelement()) + self.gpu_buffer[idx] = out + + def save_hidden_states(self, layer_idx, *hs): + if layer_idx < 0: + layer_idx = 
self.layer_num + layer_idx + hidden_state = hs[0] + position_id = hs[1] + if layer_idx < self.cpu_layer_num: + drv.memcpy_dtoh_async(self.hidden_state_cpu_buffer[layer_idx], hidden_state.data_ptr(), self.d2h_stream) + drv.memcpy_dtoh_async(self.position_id_cpu_buffer[layer_idx], position_id.data_ptr(), self.d2h_stream) + else: + idx = layer_idx - self.cpu_layer_num + self.hidden_state_gpu_buffer[idx] = hidden_state + self.position_id_gpu_buffer[idx] = position_id def get_flash_attn_out(self, layer_idx): if layer_idx < 0: layer_idx = self.layer_num + layer_idx - if layer_idx > self.cpu_layer_num: + if layer_idx >= self.cpu_layer_num: return self.gpu_buffer[layer_idx - self.cpu_layer_num] - idx = self.gpu_layer_num -1 - (self.cpu_layer_num - layer_idx) % self.gpu_layer_num + idx = self.gpu_layer_num -1 - (self.cpu_layer_num - 1 - layer_idx) % self.gpu_layer_num self.h2d_streams[idx].synchronize() return self.gpu_buffer[idx] + + def get_hidden_states(self, layer_idx): + if layer_idx < 0: + layer_idx = self.layer_num + layer_idx + if layer_idx >= self.cpu_layer_num: + return self.hidden_state_gpu_buffer[layer_idx - self.cpu_layer_num], self.position_id_gpu_buffer[layer_idx - self.cpu_layer_num] + idx = self.gpu_layer_num -1 - (self.cpu_layer_num - 1 - layer_idx) % self.gpu_layer_num + self.h2d_streams[idx].synchronize() + return self.hidden_state_gpu_buffer[idx], self.position_id_gpu_buffer[idx] - def free_flash_attn_out(self, layer_idx): + def free_layer_gpu_buffer(self, layer_idx): if layer_idx < 0: layer_idx = self.layer_num + layer_idx + if layer_idx == self.layer_num - 1: + self.d2h_stream.synchronize() cpu_layer_idx = layer_idx - self.gpu_layer_num + if layer_idx >= self.cpu_layer_num: + idx = layer_idx - self.cpu_layer_num + else: + idx = self.gpu_layer_num -1 - (self.cpu_layer_num - 1 - layer_idx) % self.gpu_layer_num + self.gpu_buffer[idx].grad = None if cpu_layer_idx < 0: + self.gpu_buffer[idx] = None + self.hidden_state_gpu_buffer[idx] = None + self.position_id_gpu_buffer[idx] = None return - idx = self.gpu_layer_num -1 - (self.cpu_layer_num - layer_idx) % self.gpu_layer_num - self.gpu_buffer[idx].grad = None drv.memcpy_htod_async(self.gpu_buffer[idx].data_ptr(), self.cpu_buffer[cpu_layer_idx], self.h2d_streams[idx]) + drv.memcpy_htod_async(self.hidden_state_gpu_buffer[idx].data_ptr(), self.hidden_state_cpu_buffer[cpu_layer_idx], self.h2d_streams[idx]) + drv.memcpy_htod_async(self.position_id_gpu_buffer[idx].data_ptr(), self.position_id_cpu_buffer[cpu_layer_idx], self.h2d_streams[idx]) + +global_buffer = GlobalBufferManager() def init_flash_attn_buffers(num_layers): # update the global buffer according to number of layers @@ -161,16 +202,24 @@ def forward(ctx, run_function, layer_idx, preserve_rng_state, *args): # Save non-tensor inputs in ctx, keep a placeholder None for tensors # to be filled out during the backward. 
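The slot arithmetic in get_flash_attn_out and free_layer_gpu_buffer above is easier to see with small numbers. A pure-Python sketch (hypothetical sizes: 8 layers with offload_percent=0.25, i.e. two GPU-resident slots) that prints where each layer's output is served from during the backward pass:

# Illustrative sketch, not part of the patch: CPU-resident layers rotate through
# the small pool of GPU slots as the backward pass walks the layers in reverse.
def gpu_slot(layer_idx, num_layers=8, offload_percent=0.25):
    gpu_layer_num = int(num_layers * offload_percent)    # slots that stay on the GPU
    cpu_layer_num = num_layers - gpu_layer_num           # layers offloaded to pinned memory
    if layer_idx >= cpu_layer_num:                       # the top layers never leave the GPU
        return f"gpu_buffer[{layer_idx - cpu_layer_num}] (resident)"
    idx = gpu_layer_num - 1 - (cpu_layer_num - 1 - layer_idx) % gpu_layer_num
    return f"gpu_buffer[{idx}] (prefetched back from cpu_buffer[{layer_idx}])"

for layer in reversed(range(8)):   # backward visits layers 7, 6, ..., 0
    print(layer, gpu_slot(layer))

Each call to free_layer_gpu_buffer in the backward pass refills the slot it just released with the next offloaded layer, so by the time get_flash_attn_out asks for that layer only a synchronize on the matching h2d stream is needed.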
ctx.inputs = [] - ctx.tensor_indices = [] + ctx.tensor_indices = {} tensor_inputs = [] + global global_buffer + hidden_state = None + position_ids = None for i, arg in enumerate(args): if i == 0 and ctx.layer_idx != 0: # flash attention output is saved to the global buffer during forward ctx.inputs.append(None) else: if torch.is_tensor(arg): - tensor_inputs.append(arg) - ctx.tensor_indices.append(i) + # tensor_inputs.append(arg) + if len(arg.shape) == 3: + hidden_state = arg + ctx.tensor_indices[i] = 'hidden_state' + elif len(arg.shape) == 2: + position_ids = arg + ctx.tensor_indices[i] = 'position_ids' ctx.inputs.append(None) else: ctx.inputs.append(arg) @@ -185,16 +234,13 @@ def forward(ctx, run_function, layer_idx, preserve_rng_state, *args): # save flash attention output to global buffer # save_flash_attn_out_to_global_buffer(ctx.layer_idx, out) - GlobalBufferManager().save_flash_attn_out(ctx.layer_idx, out) - tensor_inputs += [softmax_lse] + + global_buffer.save_flash_attn_out(ctx.layer_idx, out) + global_buffer.save_hidden_states(ctx.layer_idx, hidden_state, position_ids) + # tensor_inputs += [softmax_lse] ctx.softmax_scale = softmax_scale - ctx.save_for_backward(*tensor_inputs) - tensor_inputs_ma = 0 - for ti in tensor_inputs: - tensor_inputs_ma += ti.element_size() * ti.nelement() - if int(os.environ['RANK']) == 0: - print(f"layer: {ctx.layer_idx}, MA: {torch.cuda.memory_allocated()/(1<<30):.2f}, MR: {torch.cuda.memory_reserved()/(1<<30):.2f}, tensor_inputs_ma: {tensor_inputs_ma/(1<<30):2f}") + ctx.save_for_backward(softmax_lse) return out, residual @staticmethod @@ -207,18 +253,24 @@ def backward(ctx, *args): # Copy the list to avoid modifying original list. inputs = list(ctx.inputs) tensor_indices = ctx.tensor_indices - tensors = ctx.saved_tensors - tensors, softmax_lse = tensors[:-1], tensors[-1] - + # tensors = ctx.saved_tensors + softmax_lse = ctx.saved_tensors[0] + # tensors, softmax_lse = tensors[:-1], tensors[-1] + global global_buffer + hidden_state, position_ids = global_buffer.get_hidden_states(ctx.layer_idx) # Fill in inputs with appropriate saved tensors. # Fill the flash attention output first if ctx.layer_idx > 0: # inputs[0] should be flash attention output # inputs[0] = get_flash_attn_out_from_global_buffer(ctx.layer_idx-1) - inputs[0] = GlobalBufferManager().get_flash_attn_out(ctx.layer_idx-1) - for i, idx in enumerate(tensor_indices): - inputs[idx] = tensors[i] - + inputs[0] = global_buffer.get_flash_attn_out(ctx.layer_idx-1) + # for i, idx in enumerate(tensor_indices): + # inputs[idx] = tensors[i] + for k, v in tensor_indices.items(): + if v == 'hidden_state': + inputs[k] = hidden_state + if v == 'position_ids': + inputs[k] = position_ids # Stash the surrounding rng state, and mimic the state that was # present at this time during forward. Restore the surrounding state # when we're done. 
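Instead of routing hidden_states and position ids through ctx.save_for_backward, the forward above classifies the checkpointed tensor arguments by rank and parks them in the same global buffer, leaving only softmax_lse for autograd to keep. A minimal sketch of that round trip (hypothetical shapes; in the patch the stash is the pinned-CPU/GPU buffer rather than a dict):

# Illustrative sketch, not part of the patch: record which positional inputs are the
# 3-D hidden_states / 2-D position ids, then refill them at the same slots in backward.
import torch

args = (torch.zeros(1, 16, 64), torch.zeros(1, 16, dtype=torch.long), 0.5)  # hypothetical inputs
roles, stash = {}, {}
for i, arg in enumerate(args):                       # forward: classify by tensor rank
    if torch.is_tensor(arg) and arg.dim() == 3:
        roles[i], stash["hidden_state"] = "hidden_state", arg
    elif torch.is_tensor(arg) and arg.dim() == 2:
        roles[i], stash["position_ids"] = "position_ids", arg

inputs = [None] * len(args)                          # backward: refill the recorded slots
for i, role in roles.items():
    inputs[i] = stash[role]
print(roles)                                         # {0: 'hidden_state', 1: 'position_ids'}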
@@ -247,9 +299,8 @@ def backward(ctx, *args): #dk = torch.empty(k.shape, dtype=q.dtype, device=q.device) #dv = torch.empty(v.shape, dtype=q.dtype, device=q.device) # out = get_flash_attn_out_from_global_buffer(ctx.layer_idx) - out = GlobalBufferManager().get_flash_attn_out(ctx.layer_idx) - if int(os.environ['RANK']) == 0: - print(f"layer: {ctx.layer_idx}, MA: {torch.cuda.memory_allocated()/(1<<30):.2f}, MR: {torch.cuda.memory_reserved()/(1<<30):.2f}") + out = global_buffer.get_flash_attn_out(ctx.layer_idx) + # todo get dout dout = args[0] @@ -263,13 +314,12 @@ def backward(ctx, *args): grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs) - # write flash attention output gradients to buffer if ctx.layer_idx > 0: # write_gradient_to_flash_attn_out(ctx.layer_idx-1, detached_inputs[0].grad) - GlobalBufferManager().get_flash_attn_out(ctx.layer_idx-1).grad = detached_inputs[0].grad + global_buffer.get_flash_attn_out(ctx.layer_idx-1).grad = detached_inputs[0].grad # free_flash_attn_out_buffer(ctx.layer_idx) - GlobalBufferManager().free_flash_attn_out(ctx.layer_idx) + global_buffer.free_layer_gpu_buffer(ctx.layer_idx) return (None, None, None) + grads @@ -341,12 +391,12 @@ def backward(ctx, *args): inputs = list(ctx.inputs) tensor_indices = ctx.tensor_indices tensors = ctx.saved_tensors - + global global_buffer # Fill in inputs with appropriate saved tensors. # Fill the flash attention output first # inputs[0] should be flash attention output # inputs[0] = get_flash_attn_out_from_global_buffer(-1) - inputs[0] = GlobalBufferManager().get_flash_attn_out(-1) + inputs[0] = global_buffer.get_flash_attn_out(-1) for i, idx in enumerate(tensor_indices): inputs[idx] = tensors[i] @@ -384,10 +434,9 @@ def backward(ctx, *args): torch.autograd.backward(outputs_with_grad, args_with_grad) grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs) - # write flash attention output gradients to buffer # write_gradient_to_flash_attn_out(-1, detached_inputs[0].grad) - GlobalBufferManager().get_flash_attn_out(-1).grad = detached_inputs[0].grad + global_buffer.get_flash_attn_out(-1).grad = detached_inputs[0].grad return (None, None) + grads @@ -549,7 +598,8 @@ def forward( pass # initialize the global buffer # init_flash_attn_buffers(len(self.layers)) - GlobalBufferManager().init( + global global_buffer + global_buffer.init( self.config.num_hidden_layers, offload_percent=0.25, shape=[batch_size, self.config.num_attention_heads, seq_length, self.config.hidden_size // self.config.num_attention_heads], From 7690f834df3edb87da1a6b188f3f234ce315d186 Mon Sep 17 00:00:00 2001 From: qianhao0713 <475483052@qq.com> Date: Tue, 27 Aug 2024 17:24:31 +0800 Subject: [PATCH 22/31] merge offload buffer --- src/llamafactory/easy_context/__init__.py | 4 +- .../dist_flash_attn/monkey_patch.py | 230 ++++++------------ .../dist_flash_attn/offload_buffer.py | 92 +++++++ src/llamafactory/hparams/finetuning_args.py | 10 +- src/llamafactory/train/sft/workflow.py | 7 +- tianqing_examples/Llama3-70B.sh | 3 + 6 files changed, 190 insertions(+), 156 deletions(-) create mode 100644 src/llamafactory/easy_context/dist_flash_attn/offload_buffer.py diff --git a/src/llamafactory/easy_context/__init__.py b/src/llamafactory/easy_context/__init__.py index 8f72c15786..8dbe0a5e30 100644 --- a/src/llamafactory/easy_context/__init__.py +++ b/src/llamafactory/easy_context/__init__.py @@ -64,7 +64,7 @@ def prepare_seq_parallel_sft_inputs( raise ValueError(f"Invalid seq_algo: {seq_algo}") 
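For a sense of scale, a back-of-the-envelope estimate of the buffers this manager holds (all numbers are assumptions for illustration only: bf16, Llama-3-70B-like 80 layers, 64 heads, head dim 128, and a local sequence length of 32768 on the rank):

# Illustrative sketch, not part of the patch: rough per-layer size of the buffered
# attention outputs and hidden states (2 bytes per bf16 element; position ids are negligible).
bs, heads, seq, head_dim = 1, 64, 32768, 128
attn_out = bs * heads * seq * head_dim * 2        # shape [bs, heads, seq, head_dim]
hidden   = bs * seq * heads * head_dim * 2        # shape [bs, seq, hidden_size]
num_layers, offload_percent = 80, 0.25
gpu_layers = int(num_layers * offload_percent)    # layers whose buffers stay on the GPU
cpu_layers = num_layers - gpu_layers
GiB = 1 << 30
print(f"per layer          : {(attn_out + hidden) / GiB:.2f} GiB")               # ~1.0 GiB
print(f"kept on GPU        : {gpu_layers * (attn_out + hidden) / GiB:.0f} GiB")  # ~20 GiB
print(f"offloaded (pinned) : {cpu_layers * (attn_out + hidden) / GiB:.0f} GiB")  # ~60 GiB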
def apply_seq_parallel_monkey_patch( - seq_algo, model, sp_size=None + seq_algo, model, sp_size=None, enable_offload=False, offload_percent=0. ): assert seq_algo in ["zigzag_ring_attn", "dist_flash_attn", "ulysses_attn", "data_parallel"], f"Invalid seq_algo: {seq_algo}" assert model in ["llama", "mistral"], f"Invalid model: {model}" @@ -75,7 +75,7 @@ def apply_seq_parallel_monkey_patch( elif seq_algo == "zigzag_ring_attn" and model == "mistral": apply_zigzag_ring_attn_monkey_patch_mistral() elif seq_algo == "dist_flash_attn" and model == "llama": - apply_dist_flash_attn_monkey_patch_llama(sp_size=sp_size) + apply_dist_flash_attn_monkey_patch_llama(sp_size=sp_size, enable_offload=enable_offload, offload_percent=offload_percent) elif seq_algo == "ulysses_attn" and model == "llama": apply_ulysses_attn_monkey_patch_llama() else: diff --git a/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py b/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py index c12dafc7db..59dca890d7 100644 --- a/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py +++ b/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py @@ -5,21 +5,19 @@ import torch from torch import nn +import torch.nn.functional as F from torch.utils.checkpoint import _get_autocast_kwargs, check_backward_validity, get_device_states, set_device_states, detach_variable import transformers from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.utils import logging from einops import rearrange from .lightseq_async_attn import _lightseq_forward, _lightseq_backward from .async_communication import initialize_distributed, reset_global_memory_buffer from transformers.cache_utils import Cache -import pycuda -import pycuda.driver as drv -import pycuda.autoinit -import numpy as np -import time, os +from .offload_buffer import offload_buffer, OffloadBuffer # define a global buffer to save flash attention outputs # it's called global because it saves the outputs for all layers @@ -32,97 +30,7 @@ # hooks for the gradients of residual global_hooks = [] -class Singleton(object): - _instance = None - def __new__(class_, *args, **kwargs): - if not isinstance(class_._instance, class_): - class_._instance = object.__new__(class_, *args, **kwargs) - return class_._instance - -class GlobalBufferManager(Singleton): - - def init(self, num_layers, offload_percent, shape, dtype, device): - torch.cuda.empty_cache() - if hasattr(self, 'initialized'): - return - self.layer_num = num_layers - self.gpu_layer_num = int(num_layers * offload_percent) - self.cpu_layer_num = num_layers - self.gpu_layer_num - self.gpu_buffer = [None for _ in range(self.gpu_layer_num)] - self.cpu_buffer = drv.pagelocked_empty([self.cpu_layer_num] + shape, dtype=np.float16) - bs, num_heads, seq_len, emb_size = shape - shape_h = [bs, seq_len, num_heads * emb_size] - shape_a = [bs, seq_len] - self.hidden_state_gpu_buffer = [None for _ in range(self.gpu_layer_num)] - self.hidden_state_cpu_buffer = drv.pagelocked_empty([self.cpu_layer_num] + shape_h, dtype=np.float16) - self.position_id_gpu_buffer = [None for _ in range(self.gpu_layer_num)] - self.position_id_cpu_buffer = drv.pagelocked_empty([self.cpu_layer_num] + shape_a, dtype=np.float16) - self.d2h_stream = drv.Stream() - self.h2d_streams = [drv.Stream() for _ in range(self.gpu_layer_num)] - self.initialized = True - - def save_flash_attn_out(self, layer_idx, out): - if layer_idx < 0: - layer_idx = self.layer_num + layer_idx - if 
layer_idx < self.cpu_layer_num: - drv.memcpy_dtoh_async(self.cpu_buffer[layer_idx], out.data_ptr(), self.d2h_stream) - else: - idx = layer_idx - self.cpu_layer_num - self.gpu_buffer[idx] = out - - def save_hidden_states(self, layer_idx, *hs): - if layer_idx < 0: - layer_idx = self.layer_num + layer_idx - hidden_state = hs[0] - position_id = hs[1] - if layer_idx < self.cpu_layer_num: - drv.memcpy_dtoh_async(self.hidden_state_cpu_buffer[layer_idx], hidden_state.data_ptr(), self.d2h_stream) - drv.memcpy_dtoh_async(self.position_id_cpu_buffer[layer_idx], position_id.data_ptr(), self.d2h_stream) - else: - idx = layer_idx - self.cpu_layer_num - self.hidden_state_gpu_buffer[idx] = hidden_state - self.position_id_gpu_buffer[idx] = position_id - - def get_flash_attn_out(self, layer_idx): - if layer_idx < 0: - layer_idx = self.layer_num + layer_idx - if layer_idx >= self.cpu_layer_num: - return self.gpu_buffer[layer_idx - self.cpu_layer_num] - idx = self.gpu_layer_num -1 - (self.cpu_layer_num - 1 - layer_idx) % self.gpu_layer_num - self.h2d_streams[idx].synchronize() - return self.gpu_buffer[idx] - - def get_hidden_states(self, layer_idx): - if layer_idx < 0: - layer_idx = self.layer_num + layer_idx - if layer_idx >= self.cpu_layer_num: - return self.hidden_state_gpu_buffer[layer_idx - self.cpu_layer_num], self.position_id_gpu_buffer[layer_idx - self.cpu_layer_num] - idx = self.gpu_layer_num -1 - (self.cpu_layer_num - 1 - layer_idx) % self.gpu_layer_num - self.h2d_streams[idx].synchronize() - return self.hidden_state_gpu_buffer[idx], self.position_id_gpu_buffer[idx] - - def free_layer_gpu_buffer(self, layer_idx): - if layer_idx < 0: - layer_idx = self.layer_num + layer_idx - if layer_idx == self.layer_num - 1: - self.d2h_stream.synchronize() - cpu_layer_idx = layer_idx - self.gpu_layer_num - if layer_idx >= self.cpu_layer_num: - idx = layer_idx - self.cpu_layer_num - else: - idx = self.gpu_layer_num -1 - (self.cpu_layer_num - 1 - layer_idx) % self.gpu_layer_num - self.gpu_buffer[idx].grad = None - if cpu_layer_idx < 0: - self.gpu_buffer[idx] = None - self.hidden_state_gpu_buffer[idx] = None - self.position_id_gpu_buffer[idx] = None - return - drv.memcpy_htod_async(self.gpu_buffer[idx].data_ptr(), self.cpu_buffer[cpu_layer_idx], self.h2d_streams[idx]) - drv.memcpy_htod_async(self.hidden_state_gpu_buffer[idx].data_ptr(), self.hidden_state_cpu_buffer[cpu_layer_idx], self.h2d_streams[idx]) - drv.memcpy_htod_async(self.position_id_gpu_buffer[idx].data_ptr(), self.position_id_cpu_buffer[cpu_layer_idx], self.h2d_streams[idx]) - -global_buffer = GlobalBufferManager() - +logger = logging.get_logger(__name__) def init_flash_attn_buffers(num_layers): # update the global buffer according to number of layers global global_flash_attn_out_buffer @@ -201,10 +109,12 @@ def forward(ctx, run_function, layer_idx, preserve_rng_state, *args): # Save non-tensor inputs in ctx, keep a placeholder None for tensors # to be filled out during the backward. 
+ global offload_buffer ctx.inputs = [] - ctx.tensor_indices = {} + ctx.tensor_indices = [] + ctx.tensor_indices_dict = {} tensor_inputs = [] - global global_buffer + hidden_state = None position_ids = None for i, arg in enumerate(args): @@ -213,13 +123,16 @@ def forward(ctx, run_function, layer_idx, preserve_rng_state, *args): ctx.inputs.append(None) else: if torch.is_tensor(arg): - # tensor_inputs.append(arg) - if len(arg.shape) == 3: - hidden_state = arg - ctx.tensor_indices[i] = 'hidden_state' - elif len(arg.shape) == 2: - position_ids = arg - ctx.tensor_indices[i] = 'position_ids' + if offload_buffer.enable_offload: + if len(arg.shape) == 3: + hidden_state = arg + ctx.tensor_indices_dict[i] = 'hidden_state' + elif len(arg.shape) == 2: + position_ids = arg + ctx.tensor_indices_dict[i] = 'position_ids' + else: + tensor_inputs.append(arg) + ctx.tensor_indices.append(i) ctx.inputs.append(None) else: ctx.inputs.append(arg) @@ -231,16 +144,18 @@ def forward(ctx, run_function, layer_idx, preserve_rng_state, *args): # lightseq version _, _, _, out, softmax_lse = _lightseq_forward(q, k, v, True, softmax_scale, comm_mode='lightseq') rng_state = None - # save flash attention output to global buffer - # save_flash_attn_out_to_global_buffer(ctx.layer_idx, out) + if offload_buffer.enable_offload: + offload_buffer.save_flash_attn_out(ctx.layer_idx, out) + offload_buffer.save_hidden_states(ctx.layer_idx, hidden_state, position_ids) + ctx.save_for_backward(softmax_lse) + else: + save_flash_attn_out_to_global_buffer(ctx.layer_idx, out) + tensor_inputs += [softmax_lse] + ctx.save_for_backward(*tensor_inputs) - global_buffer.save_flash_attn_out(ctx.layer_idx, out) - global_buffer.save_hidden_states(ctx.layer_idx, hidden_state, position_ids) - # tensor_inputs += [softmax_lse] ctx.softmax_scale = softmax_scale - - ctx.save_for_backward(softmax_lse) + return out, residual @staticmethod @@ -251,26 +166,28 @@ def backward(ctx, *args): " is passed to .backward(). Please use .backward() and do not pass its `inputs`" " argument.") # Copy the list to avoid modifying original list. + global offload_buffer inputs = list(ctx.inputs) tensor_indices = ctx.tensor_indices - # tensors = ctx.saved_tensors - softmax_lse = ctx.saved_tensors[0] - # tensors, softmax_lse = tensors[:-1], tensors[-1] - global global_buffer - hidden_state, position_ids = global_buffer.get_hidden_states(ctx.layer_idx) - # Fill in inputs with appropriate saved tensors. 
- # Fill the flash attention output first - if ctx.layer_idx > 0: - # inputs[0] should be flash attention output - # inputs[0] = get_flash_attn_out_from_global_buffer(ctx.layer_idx-1) - inputs[0] = global_buffer.get_flash_attn_out(ctx.layer_idx-1) - # for i, idx in enumerate(tensor_indices): - # inputs[idx] = tensors[i] - for k, v in tensor_indices.items(): - if v == 'hidden_state': - inputs[k] = hidden_state - if v == 'position_ids': - inputs[k] = position_ids + tensor_indices_dict = ctx.tensor_indices_dict + tensors = ctx.saved_tensors + if offload_buffer.enable_offload: + softmax_lse = tensors[0] + hidden_state, position_ids = offload_buffer.get_hidden_states(ctx.layer_idx) + if ctx.layer_idx > 0: + inputs[0] = offload_buffer.get_flash_attn_out(ctx.layer_idx-1) + for k, v in tensor_indices_dict.items(): + if v == 'hidden_state': + inputs[k] = hidden_state + if v == 'position_ids': + inputs[k] = position_ids + else: + tensors, softmax_lse = tensors[:-1], tensors[-1] + if ctx.layer_idx > 0: + inputs[0] = get_flash_attn_out_from_global_buffer(ctx.layer_idx-1) + for i, idx in enumerate(tensor_indices): + inputs[idx] = tensors[i] + # Stash the surrounding rng state, and mimic the state that was # present at this time during forward. Restore the surrounding state # when we're done. @@ -299,7 +216,10 @@ def backward(ctx, *args): #dk = torch.empty(k.shape, dtype=q.dtype, device=q.device) #dv = torch.empty(v.shape, dtype=q.dtype, device=q.device) # out = get_flash_attn_out_from_global_buffer(ctx.layer_idx) - out = global_buffer.get_flash_attn_out(ctx.layer_idx) + if offload_buffer.enable_offload: + out = offload_buffer.get_flash_attn_out(ctx.layer_idx) + else: + out = get_flash_attn_out_from_global_buffer(ctx.layer_idx) # todo get dout dout = args[0] @@ -315,11 +235,13 @@ def backward(ctx, *args): grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs) # write flash attention output gradients to buffer - if ctx.layer_idx > 0: - # write_gradient_to_flash_attn_out(ctx.layer_idx-1, detached_inputs[0].grad) - global_buffer.get_flash_attn_out(ctx.layer_idx-1).grad = detached_inputs[0].grad - # free_flash_attn_out_buffer(ctx.layer_idx) - global_buffer.free_layer_gpu_buffer(ctx.layer_idx) + if offload_buffer.enable_offload: + if ctx.layer_idx > 0: + offload_buffer.get_flash_attn_out(ctx.layer_idx-1).grad = detached_inputs[0].grad + offload_buffer.free_layer_gpu_buffer(ctx.layer_idx) + else: + if ctx.layer_idx > 0: + write_gradient_to_flash_attn_out(ctx.layer_idx-1, detached_inputs[0].grad) return (None, None, None) + grads @@ -391,12 +313,14 @@ def backward(ctx, *args): inputs = list(ctx.inputs) tensor_indices = ctx.tensor_indices tensors = ctx.saved_tensors - global global_buffer + global offload_buffer # Fill in inputs with appropriate saved tensors. 
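A note on the offload_buffer global used throughout these hunks: "from .offload_buffer import offload_buffer" copies the binding, so the "global offload_buffer" rebinding inside apply_dist_flash_attn_monkey_patch_llama only updates the name local to monkey_patch.py. That is enough here because every functional reader of the name lives in the same module, but it also means resetting the name from another module has no effect, which is presumably why a later patch drops the workflow-level reset again. A minimal sketch of the binding behaviour (hypothetical stand-in module):

# Illustrative sketch, not part of the patch: rebinding a name obtained via
# "from module import name" does not change the attribute on the source module.
import types

source = types.ModuleType("offload_buffer_stub")   # stands in for offload_buffer.py
source.offload_buffer = None

offload_buffer = source.offload_buffer             # what the "from ... import" does
offload_buffer = "OffloadBuffer(...)"              # rebinding the importing module's name

print(source.offload_buffer)                       # still None: the source module is untouched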
# Fill the flash attention output first # inputs[0] should be flash attention output - # inputs[0] = get_flash_attn_out_from_global_buffer(-1) - inputs[0] = global_buffer.get_flash_attn_out(-1) + if offload_buffer.enable_offload: + inputs[0] = offload_buffer.get_flash_attn_out(-1) + else: + inputs[0] = get_flash_attn_out_from_global_buffer(-1) for i, idx in enumerate(tensor_indices): inputs[idx] = tensors[i] @@ -435,9 +359,10 @@ def backward(ctx, *args): grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs) # write flash attention output gradients to buffer - # write_gradient_to_flash_attn_out(-1, detached_inputs[0].grad) - global_buffer.get_flash_attn_out(-1).grad = detached_inputs[0].grad - + if offload_buffer.enable_offload: + offload_buffer.get_flash_attn_out(-1).grad = detached_inputs[0].grad + else: + write_gradient_to_flash_attn_out(-1, detached_inputs[0].grad) return (None, None) + grads def checkpoint_last_module(function, *args, use_reentrant: bool = True, **kwargs): @@ -598,15 +523,14 @@ def forward( pass # initialize the global buffer # init_flash_attn_buffers(len(self.layers)) - global global_buffer - global_buffer.init( - self.config.num_hidden_layers, - offload_percent=0.25, - shape=[batch_size, self.config.num_attention_heads, seq_length, self.config.hidden_size // self.config.num_attention_heads], - dtype=hidden_states.dtype, - device=hidden_states.device - ) - + global offload_buffer + if offload_buffer.enable_offload: + offload_buffer.allocate( + self.config.num_hidden_layers, + shape=[batch_size, self.config.num_attention_heads, seq_length, self.config.hidden_size // self.config.num_attention_heads] + ) + else: + init_flash_attn_buffers(len(self.layers)) if use_cache: try: logger.warning_once( @@ -782,8 +706,10 @@ def llama_model_forward( attentions=outputs.attentions, ) -def apply_dist_flash_attn_monkey_patch_llama(sp_size=None): +def apply_dist_flash_attn_monkey_patch_llama(sp_size=None, enable_offload=False, offload_percent=0.): initialize_distributed(sp_size=sp_size) + global offload_buffer + offload_buffer = OffloadBuffer(enable_offload=enable_offload, offload_percent=offload_percent) transformers.models.llama.modeling_llama.LlamaModel.forward = forward transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = llama_layer_forward transformers.models.llama.modeling_llama.LlamaForCausalLM.forward = llama_model_forward diff --git a/src/llamafactory/easy_context/dist_flash_attn/offload_buffer.py b/src/llamafactory/easy_context/dist_flash_attn/offload_buffer.py new file mode 100644 index 0000000000..8cbe8f9a69 --- /dev/null +++ b/src/llamafactory/easy_context/dist_flash_attn/offload_buffer.py @@ -0,0 +1,92 @@ +import pycuda +import pycuda.driver as drv +import pycuda.autoinit +import numpy as np + +class OffloadBuffer: + + def __init__(self, enable_offload, offload_percent): + self.enable_offload = enable_offload + self.offload_percent = offload_percent + self.allocated = False + + def allocate(self, num_layers, shape): + if self.allocated: + return + self.layer_num = num_layers + self.gpu_layer_num = int(num_layers * self.offload_percent) + self.cpu_layer_num = num_layers - self.gpu_layer_num + self.gpu_buffer = [None for _ in range(self.gpu_layer_num)] + self.cpu_buffer = drv.pagelocked_empty([self.cpu_layer_num] + shape, dtype=np.float16) + bs, num_heads, seq_len, emb_size = shape + shape_h = [bs, seq_len, num_heads * emb_size] + shape_a = [bs, seq_len] + self.hidden_state_gpu_buffer = [None for _ in 
range(self.gpu_layer_num)] + self.hidden_state_cpu_buffer = drv.pagelocked_empty([self.cpu_layer_num] + shape_h, dtype=np.float16) + self.position_id_gpu_buffer = [None for _ in range(self.gpu_layer_num)] + self.position_id_cpu_buffer = drv.pagelocked_empty([self.cpu_layer_num] + shape_a, dtype=np.float16) + self.d2h_stream = drv.Stream() + self.h2d_streams = [drv.Stream() for _ in range(self.gpu_layer_num)] + self.allocated = True + + def save_flash_attn_out(self, layer_idx, out): + if layer_idx < 0: + layer_idx = self.layer_num + layer_idx + if layer_idx < self.cpu_layer_num: + drv.memcpy_dtoh_async(self.cpu_buffer[layer_idx], out.data_ptr(), self.d2h_stream) + else: + idx = layer_idx - self.cpu_layer_num + self.gpu_buffer[idx] = out + + def save_hidden_states(self, layer_idx, *hs): + if layer_idx < 0: + layer_idx = self.layer_num + layer_idx + hidden_state = hs[0] + position_id = hs[1] + if layer_idx < self.cpu_layer_num: + drv.memcpy_dtoh_async(self.hidden_state_cpu_buffer[layer_idx], hidden_state.data_ptr(), self.d2h_stream) + drv.memcpy_dtoh_async(self.position_id_cpu_buffer[layer_idx], position_id.data_ptr(), self.d2h_stream) + else: + idx = layer_idx - self.cpu_layer_num + self.hidden_state_gpu_buffer[idx] = hidden_state + self.position_id_gpu_buffer[idx] = position_id + + def get_flash_attn_out(self, layer_idx): + if layer_idx < 0: + layer_idx = self.layer_num + layer_idx + if layer_idx >= self.cpu_layer_num: + return self.gpu_buffer[layer_idx - self.cpu_layer_num] + idx = self.gpu_layer_num -1 - (self.cpu_layer_num - 1 - layer_idx) % self.gpu_layer_num + self.h2d_streams[idx].synchronize() + return self.gpu_buffer[idx] + + def get_hidden_states(self, layer_idx): + if layer_idx < 0: + layer_idx = self.layer_num + layer_idx + if layer_idx >= self.cpu_layer_num: + return self.hidden_state_gpu_buffer[layer_idx - self.cpu_layer_num], self.position_id_gpu_buffer[layer_idx - self.cpu_layer_num] + idx = self.gpu_layer_num -1 - (self.cpu_layer_num - 1 - layer_idx) % self.gpu_layer_num + self.h2d_streams[idx].synchronize() + return self.hidden_state_gpu_buffer[idx], self.position_id_gpu_buffer[idx] + + def free_layer_gpu_buffer(self, layer_idx): + if layer_idx < 0: + layer_idx = self.layer_num + layer_idx + if layer_idx == self.layer_num - 1: + self.d2h_stream.synchronize() + cpu_layer_idx = layer_idx - self.gpu_layer_num + if layer_idx >= self.cpu_layer_num: + idx = layer_idx - self.cpu_layer_num + else: + idx = self.gpu_layer_num -1 - (self.cpu_layer_num - 1 - layer_idx) % self.gpu_layer_num + self.gpu_buffer[idx].grad = None + if cpu_layer_idx < 0: + self.gpu_buffer[idx] = None + self.hidden_state_gpu_buffer[idx] = None + self.position_id_gpu_buffer[idx] = None + return + drv.memcpy_htod_async(self.gpu_buffer[idx].data_ptr(), self.cpu_buffer[cpu_layer_idx], self.h2d_streams[idx]) + drv.memcpy_htod_async(self.hidden_state_gpu_buffer[idx].data_ptr(), self.hidden_state_cpu_buffer[cpu_layer_idx], self.h2d_streams[idx]) + drv.memcpy_htod_async(self.position_id_gpu_buffer[idx].data_ptr(), self.position_id_cpu_buffer[cpu_layer_idx], self.h2d_streams[idx]) + +offload_buffer = None \ No newline at end of file diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index eddbf3217f..5f9889c82b 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -320,7 +320,15 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA default=-1, metadata={ "help": "allow using 
seq_parallel and data_parallel simultaneously, -1 for all gpus parallels in sequence_length axis, n for n_gpus makes a sequence_parallel group" - } + }, + ) + sp_enable_offload: bool = field( + default=False, + metadata={"help": "whether enable offload activation to cpu for dist_flash_attn"}, + ) + sp_offload_percent: float = field( + default=0.0, + metadata={"help": "0 for remain all activation memory in gpu, 1 for offload all activation memory in cpu"} ) def __post_init__(self): diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py index e570d3ef71..b3420f1cc3 100644 --- a/src/llamafactory/train/sft/workflow.py +++ b/src/llamafactory/train/sft/workflow.py @@ -17,6 +17,7 @@ import torch import os from ...easy_context import apply_seq_parallel_monkey_patch +from ...easy_context.dist_flash_attn.offload_buffer import offload_buffer if TYPE_CHECKING: @@ -37,7 +38,7 @@ def run_sft( tokenizer = tokenizer_module["tokenizer"] dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) - apply_seq_parallel_monkey_patch(finetuning_args.parallel_mode, "llama", sp_size=finetuning_args.sp_size) + apply_seq_parallel_monkey_patch(finetuning_args.parallel_mode, "llama", sp_size=finetuning_args.sp_size, enable_offload=finetuning_args.sp_enable_offload, offload_percent=finetuning_args.sp_offload_percent) if training_args.predict_with_generate: tokenizer.padding_side = "left" # use left-padding in generation @@ -90,6 +91,10 @@ def run_sft( trainer.save_state() if trainer.is_world_process_zero() and finetuning_args.plot_loss: plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) + + global offload_buffer + if offload_buffer is not None: + offload_buffer = None # # Evaluation # if training_args.do_eval: diff --git a/tianqing_examples/Llama3-70B.sh b/tianqing_examples/Llama3-70B.sh index f88008de0e..93475ced61 100644 --- a/tianqing_examples/Llama3-70B.sh +++ b/tianqing_examples/Llama3-70B.sh @@ -6,6 +6,7 @@ WORLD_SIZE=${WORLD_SIZE:-1} NUM_PROCESSES=$[${NGPUS}*$[WORLD_SIZE]] SEQ_LEN=${SEQ_LEN:-32768} SP_SIZE=${SP_SIZE:-1} +SP_OFFLOAD_PERCENT=${SP_OFFLOAD_PERCENT:-0.8} BATCH_SIZE=${BATCH_SIZE:-1} export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' export WANDB_DISABLED=true @@ -32,6 +33,8 @@ src/train.py \ --finetuning_type full \ --parallel_mode dist_flash_attn \ --sp_size ${SP_SIZE} \ +--sp_enable_offload \ +--sp_offload_percent ${SP_OFFLOAD_PERCENT} \ --deepspeed examples/deepspeed/ds_z3_offload_config.json \ --dataset long_sft_32k \ --template llama3 \ From 09f67e554682b6b1be2af79334f35c701f88b24c Mon Sep 17 00:00:00 2001 From: qianhao0713 <475483052@qq.com> Date: Tue, 3 Sep 2024 15:35:25 +0800 Subject: [PATCH 23/31] pycuda -> cudart --- .../dist_flash_attn/offload_buffer.py | 53 ++++++++++++------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/src/llamafactory/easy_context/dist_flash_attn/offload_buffer.py b/src/llamafactory/easy_context/dist_flash_attn/offload_buffer.py index 8cbe8f9a69..dabba55f5e 100644 --- a/src/llamafactory/easy_context/dist_flash_attn/offload_buffer.py +++ b/src/llamafactory/easy_context/dist_flash_attn/offload_buffer.py @@ -1,7 +1,6 @@ -import pycuda -import pycuda.driver as drv -import pycuda.autoinit -import numpy as np +import cuda +import cuda.cudart +import torch class OffloadBuffer: @@ -14,26 +13,29 @@ def allocate(self, num_layers, shape): if self.allocated: return self.layer_num = num_layers - 
self.gpu_layer_num = int(num_layers * self.offload_percent) - self.cpu_layer_num = num_layers - self.gpu_layer_num + self.cpu_layer_num = int(num_layers * self.offload_percent) + self.gpu_layer_num = num_layers - self.cpu_layer_num self.gpu_buffer = [None for _ in range(self.gpu_layer_num)] - self.cpu_buffer = drv.pagelocked_empty([self.cpu_layer_num] + shape, dtype=np.float16) + self.cpu_buffer = [torch.empty([self.cpu_layer_num] + shape, dtype=torch.bfloat16) for _ in range(self.cpu_layer_num)] bs, num_heads, seq_len, emb_size = shape shape_h = [bs, seq_len, num_heads * emb_size] shape_a = [bs, seq_len] self.hidden_state_gpu_buffer = [None for _ in range(self.gpu_layer_num)] - self.hidden_state_cpu_buffer = drv.pagelocked_empty([self.cpu_layer_num] + shape_h, dtype=np.float16) + self.hidden_state_cpu_buffer = [torch.empty([self.cpu_layer_num] + shape_h, dtype=torch.bfloat16) for _ in range(self.cpu_layer_num)] self.position_id_gpu_buffer = [None for _ in range(self.gpu_layer_num)] - self.position_id_cpu_buffer = drv.pagelocked_empty([self.cpu_layer_num] + shape_a, dtype=np.float16) - self.d2h_stream = drv.Stream() - self.h2d_streams = [drv.Stream() for _ in range(self.gpu_layer_num)] + self.position_id_cpu_buffer = [torch.empty([self.cpu_layer_num] + shape_a, dtype=torch.bfloat16) for _ in range(self.cpu_layer_num)] + _, self.d2h_stream = cuda.cudart.cudaStreamCreate() + self.h2d_streams = [] + for i in range(self.gpu_layer_num): + _, h2d_stream = cuda.cudart.cudaStreamCreate() + self.h2d_streams.append(h2d_stream) self.allocated = True def save_flash_attn_out(self, layer_idx, out): if layer_idx < 0: layer_idx = self.layer_num + layer_idx if layer_idx < self.cpu_layer_num: - drv.memcpy_dtoh_async(self.cpu_buffer[layer_idx], out.data_ptr(), self.d2h_stream) + _ = cuda.cudart.cudaMemcpyAsync(self.cpu_buffer[layer_idx].data_ptr(), out.data_ptr(), out.nelement() * out.element_size(), cuda.cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, self.d2h_stream) else: idx = layer_idx - self.cpu_layer_num self.gpu_buffer[idx] = out @@ -44,8 +46,8 @@ def save_hidden_states(self, layer_idx, *hs): hidden_state = hs[0] position_id = hs[1] if layer_idx < self.cpu_layer_num: - drv.memcpy_dtoh_async(self.hidden_state_cpu_buffer[layer_idx], hidden_state.data_ptr(), self.d2h_stream) - drv.memcpy_dtoh_async(self.position_id_cpu_buffer[layer_idx], position_id.data_ptr(), self.d2h_stream) + _ = cuda.cudart.cudaMemcpyAsync(self.hidden_state_cpu_buffer[layer_idx].data_ptr(), hidden_state.data_ptr(), hidden_state.nelement() * hidden_state.element_size(), cuda.cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, self.d2h_stream) + _ = cuda.cudart.cudaMemcpyAsync(self.position_id_cpu_buffer[layer_idx].data_ptr(), position_id.data_ptr(), position_id.nelement() * position_id.element_size(), cuda.cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, self.d2h_stream) else: idx = layer_idx - self.cpu_layer_num self.hidden_state_gpu_buffer[idx] = hidden_state @@ -57,7 +59,7 @@ def get_flash_attn_out(self, layer_idx): if layer_idx >= self.cpu_layer_num: return self.gpu_buffer[layer_idx - self.cpu_layer_num] idx = self.gpu_layer_num -1 - (self.cpu_layer_num - 1 - layer_idx) % self.gpu_layer_num - self.h2d_streams[idx].synchronize() + _ = cuda.cudart.cudaStreamSynchronize(self.h2d_streams[idx]) return self.gpu_buffer[idx] def get_hidden_states(self, layer_idx): @@ -66,14 +68,14 @@ def get_hidden_states(self, layer_idx): if layer_idx >= self.cpu_layer_num: return self.hidden_state_gpu_buffer[layer_idx - self.cpu_layer_num], 
self.position_id_gpu_buffer[layer_idx - self.cpu_layer_num] idx = self.gpu_layer_num -1 - (self.cpu_layer_num - 1 - layer_idx) % self.gpu_layer_num - self.h2d_streams[idx].synchronize() + _ = cuda.cudart.cudaStreamSynchronize(self.h2d_streams[idx]) return self.hidden_state_gpu_buffer[idx], self.position_id_gpu_buffer[idx] def free_layer_gpu_buffer(self, layer_idx): if layer_idx < 0: layer_idx = self.layer_num + layer_idx if layer_idx == self.layer_num - 1: - self.d2h_stream.synchronize() + _ = cuda.cudart.cudaStreamSynchronize(self.d2h_stream) cpu_layer_idx = layer_idx - self.gpu_layer_num if layer_idx >= self.cpu_layer_num: idx = layer_idx - self.cpu_layer_num @@ -85,8 +87,19 @@ def free_layer_gpu_buffer(self, layer_idx): self.hidden_state_gpu_buffer[idx] = None self.position_id_gpu_buffer[idx] = None return - drv.memcpy_htod_async(self.gpu_buffer[idx].data_ptr(), self.cpu_buffer[cpu_layer_idx], self.h2d_streams[idx]) - drv.memcpy_htod_async(self.hidden_state_gpu_buffer[idx].data_ptr(), self.hidden_state_cpu_buffer[cpu_layer_idx], self.h2d_streams[idx]) - drv.memcpy_htod_async(self.position_id_gpu_buffer[idx].data_ptr(), self.position_id_cpu_buffer[cpu_layer_idx], self.h2d_streams[idx]) + cb = self.cpu_buffer[cpu_layer_idx] + hcb = self.hidden_state_cpu_buffer[cpu_layer_idx] + pcb = self.position_id_cpu_buffer[cpu_layer_idx] + gb = self.gpu_buffer[idx] + hgb = self.hidden_state_gpu_buffer[idx] + pgb = self.position_id_gpu_buffer[idx] + _ = cuda.cudart.cudaMemcpyAsync(gb.data_ptr(), cb.data_ptr(), gb.nelement() * gb.element_size(), cuda.cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, self.h2d_streams[idx]) + _ = cuda.cudart.cudaMemcpyAsync(hgb.data_ptr(), hcb.data_ptr(), hgb.nelement() * hgb.element_size(), cuda.cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, self.h2d_streams[idx]) + _ = cuda.cudart.cudaMemcpyAsync(pgb.data_ptr(), pcb.data_ptr(), pgb.nelement() * pgb.element_size(), cuda.cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, self.h2d_streams[idx]) + + def __del__(self): + cuda.cudart.cudaStreamDestroy(self.d2h_stream) + for i in range(self.gpu_layer_num): + cuda.cudart.cudaStreamDestroy(self.h2d_streams[i]) offload_buffer = None \ No newline at end of file From 949763f2b2fbbbdea35e1b99fb27bc8b01b869a7 Mon Sep 17 00:00:00 2001 From: qianhao0713 <475483052@qq.com> Date: Wed, 4 Sep 2024 09:47:05 +0800 Subject: [PATCH 24/31] add offload examples --- .../dist_flash_attn/offload_buffer.py | 15 ++--- tianqing_examples/Llama3-70B-sp-offload.sh | 60 +++++++++++++++++++ tianqing_examples/Llama3-70B.sh | 3 - 3 files changed, 68 insertions(+), 10 deletions(-) create mode 100644 tianqing_examples/Llama3-70B-sp-offload.sh diff --git a/src/llamafactory/easy_context/dist_flash_attn/offload_buffer.py b/src/llamafactory/easy_context/dist_flash_attn/offload_buffer.py index dabba55f5e..ed20759ab0 100644 --- a/src/llamafactory/easy_context/dist_flash_attn/offload_buffer.py +++ b/src/llamafactory/easy_context/dist_flash_attn/offload_buffer.py @@ -16,14 +16,14 @@ def allocate(self, num_layers, shape): self.cpu_layer_num = int(num_layers * self.offload_percent) self.gpu_layer_num = num_layers - self.cpu_layer_num self.gpu_buffer = [None for _ in range(self.gpu_layer_num)] - self.cpu_buffer = [torch.empty([self.cpu_layer_num] + shape, dtype=torch.bfloat16) for _ in range(self.cpu_layer_num)] + self.cpu_buffer = [torch.empty(shape, dtype=torch.bfloat16, pin_memory=True) for _ in range(self.cpu_layer_num)] bs, num_heads, seq_len, emb_size = shape shape_h = [bs, seq_len, num_heads * emb_size] shape_a = [bs, 
seq_len] self.hidden_state_gpu_buffer = [None for _ in range(self.gpu_layer_num)] - self.hidden_state_cpu_buffer = [torch.empty([self.cpu_layer_num] + shape_h, dtype=torch.bfloat16) for _ in range(self.cpu_layer_num)] + self.hidden_state_cpu_buffer = [torch.empty(shape_h, dtype=torch.bfloat16, pin_memory=True) for _ in range(self.cpu_layer_num)] self.position_id_gpu_buffer = [None for _ in range(self.gpu_layer_num)] - self.position_id_cpu_buffer = [torch.empty([self.cpu_layer_num] + shape_a, dtype=torch.bfloat16) for _ in range(self.cpu_layer_num)] + self.position_id_cpu_buffer = [torch.empty(shape_a, dtype=torch.bfloat16, pin_memory=True) for _ in range(self.cpu_layer_num)] _, self.d2h_stream = cuda.cudart.cudaStreamCreate() self.h2d_streams = [] for i in range(self.gpu_layer_num): @@ -98,8 +98,9 @@ def free_layer_gpu_buffer(self, layer_idx): _ = cuda.cudart.cudaMemcpyAsync(pgb.data_ptr(), pcb.data_ptr(), pgb.nelement() * pgb.element_size(), cuda.cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, self.h2d_streams[idx]) def __del__(self): - cuda.cudart.cudaStreamDestroy(self.d2h_stream) - for i in range(self.gpu_layer_num): - cuda.cudart.cudaStreamDestroy(self.h2d_streams[i]) + if self.allocated: + cuda.cudart.cudaStreamDestroy(self.d2h_stream) + for i in range(self.gpu_layer_num): + cuda.cudart.cudaStreamDestroy(self.h2d_streams[i]) -offload_buffer = None \ No newline at end of file +offload_buffer = None diff --git a/tianqing_examples/Llama3-70B-sp-offload.sh b/tianqing_examples/Llama3-70B-sp-offload.sh new file mode 100644 index 0000000000..93475ced61 --- /dev/null +++ b/tianqing_examples/Llama3-70B-sp-offload.sh @@ -0,0 +1,60 @@ +# You can observe that the number of steps for different stage is quite different. They are not magic number. They are set to those numbers simply because I esitimate the time it takes to finish the training, and +# choose the number such that it fits my daily schedule>_<. This is for you to exactly reproduce my results. You many change the steps to other numbers if you want to. 
+MODEL_DIR=${MODEL_DIR:-"/root/model/Meta-Llama-3-70B"} +NGPUS=${NGPUS:-8} +WORLD_SIZE=${WORLD_SIZE:-1} +NUM_PROCESSES=$[${NGPUS}*$[WORLD_SIZE]] +SEQ_LEN=${SEQ_LEN:-32768} +SP_SIZE=${SP_SIZE:-1} +SP_OFFLOAD_PERCENT=${SP_OFFLOAD_PERCENT:-0.8} +BATCH_SIZE=${BATCH_SIZE:-1} +export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' +export WANDB_DISABLED=true +export NCCL_DEBUG=WARN +echo ${RANK}/$[WORLD_SIZE] +if [ ${MASTER_ADDR} == 'localhost' ]; then + export MASTER_ADDR=`hostname -i` +fi +echo ${MASTER_ADDR}:${MASTER_PORT} + +accelerate launch \ +--config_file examples/accelerate/ds_multi_nodes.yaml \ +--use_deepspeed \ +--num_machines ${WORLD_SIZE} \ +--num_processes ${NUM_PROCESSES} \ +--main_process_ip ${MASTER_ADDR} \ +--main_process_port ${MASTER_PORT} \ +--machine_rank ${RANK} \ +--rdzv_conf "rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT,rdzv_backend=c10d" \ +src/train.py \ +--model_name_or_path ${MODEL_DIR} \ +--stage sft \ +--do_train \ +--finetuning_type full \ +--parallel_mode dist_flash_attn \ +--sp_size ${SP_SIZE} \ +--sp_enable_offload \ +--sp_offload_percent ${SP_OFFLOAD_PERCENT} \ +--deepspeed examples/deepspeed/ds_z3_offload_config.json \ +--dataset long_sft_32k \ +--template llama3 \ +--cutoff_len ${SEQ_LEN} \ +--max_steps 10 \ +--overwrite_cache \ +--preprocessing_num_workers 16 \ +--output_dir ./output/70B_32K_bs_1M_rope_1M_step_1000_lr_2e-5 \ +--logging_steps 1 \ +--save_steps 500 \ +--plot_loss \ +--overwrite_output_dir \ +--per_device_train_batch_size ${BATCH_SIZE} \ +--gradient_accumulation_steps 4 \ +--learning_rate 2e-5 \ +--num_train_epochs 1.0 \ +--lr_scheduler_type cosine \ +--warmup_ratio 0.1 \ +--bf16 \ +--ddp_timeout 180000000 \ +--val_size 0.1 \ +--eval_strategy steps \ +--eval_steps 1000 diff --git a/tianqing_examples/Llama3-70B.sh b/tianqing_examples/Llama3-70B.sh index 93475ced61..f88008de0e 100644 --- a/tianqing_examples/Llama3-70B.sh +++ b/tianqing_examples/Llama3-70B.sh @@ -6,7 +6,6 @@ WORLD_SIZE=${WORLD_SIZE:-1} NUM_PROCESSES=$[${NGPUS}*$[WORLD_SIZE]] SEQ_LEN=${SEQ_LEN:-32768} SP_SIZE=${SP_SIZE:-1} -SP_OFFLOAD_PERCENT=${SP_OFFLOAD_PERCENT:-0.8} BATCH_SIZE=${BATCH_SIZE:-1} export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' export WANDB_DISABLED=true @@ -33,8 +32,6 @@ src/train.py \ --finetuning_type full \ --parallel_mode dist_flash_attn \ --sp_size ${SP_SIZE} \ ---sp_enable_offload \ ---sp_offload_percent ${SP_OFFLOAD_PERCENT} \ --deepspeed examples/deepspeed/ds_z3_offload_config.json \ --dataset long_sft_32k \ --template llama3 \ From d81e737a02fe1e52f199cc63ea4816c2b96aa024 Mon Sep 17 00:00:00 2001 From: qianhao0713 <475483052@qq.com> Date: Tue, 22 Oct 2024 19:16:18 +0800 Subject: [PATCH 25/31] add dp&sp for zigzag and ulysses --- src/llamafactory/easy_context/__init__.py | 6 +-- .../easy_context/ulysses_attn/monkey_patch.py | 39 ++++++++++++---- .../zigzag_ring_attn/monkey_patch.py | 45 ++++++++++++++----- src/llamafactory/train/sft/workflow.py | 5 --- 4 files changed, 67 insertions(+), 28 deletions(-) diff --git a/src/llamafactory/easy_context/__init__.py b/src/llamafactory/easy_context/__init__.py index 8dbe0a5e30..bffbbba570 100644 --- a/src/llamafactory/easy_context/__init__.py +++ b/src/llamafactory/easy_context/__init__.py @@ -71,13 +71,13 @@ def apply_seq_parallel_monkey_patch( if seq_algo == "data_parallel": return elif seq_algo == "zigzag_ring_attn" and model == "llama": - apply_zigzag_ring_attn_monkey_patch_llama() + apply_zigzag_ring_attn_monkey_patch_llama(sp_size=sp_size) elif seq_algo == "zigzag_ring_attn" and model == "mistral": - 
apply_zigzag_ring_attn_monkey_patch_mistral() + apply_zigzag_ring_attn_monkey_patch_mistral(sp_size=sp_size) elif seq_algo == "dist_flash_attn" and model == "llama": apply_dist_flash_attn_monkey_patch_llama(sp_size=sp_size, enable_offload=enable_offload, offload_percent=offload_percent) elif seq_algo == "ulysses_attn" and model == "llama": - apply_ulysses_attn_monkey_patch_llama() + apply_ulysses_attn_monkey_patch_llama(sp_size=sp_size) else: raise ValueError(f"Invalid seq_algo: {seq_algo} or model: {model}") diff --git a/src/llamafactory/easy_context/ulysses_attn/monkey_patch.py b/src/llamafactory/easy_context/ulysses_attn/monkey_patch.py index a2cba43022..0ba10141b8 100644 --- a/src/llamafactory/easy_context/ulysses_attn/monkey_patch.py +++ b/src/llamafactory/easy_context/ulysses_attn/monkey_patch.py @@ -5,7 +5,7 @@ import torch.utils.checkpoint from yunchang.ulysses import UlyssesAttention -ulysses_attn = UlyssesAttention() +ulysses_attn = None def new_flash_attn_forward( self, @@ -50,12 +50,12 @@ def new_decoder_forward( cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - assert isinstance( - self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2 - ) or isinstance( - self.self_attn, - transformers.models.mistral.modeling_mistral.MistralFlashAttention2, - ), "Please toggle on the Flash Attention 2 implementation when using zigzag ring attention monkey patch." + # assert isinstance( + # self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2 + # ) or isinstance( + # self.self_attn, + # transformers.models.mistral.modeling_mistral.MistralFlashAttention2, + # ), "Please toggle on the Flash Attention 2 implementation when using zigzag ring attention monkey patch." 
if "padding_mask" in kwargs: warnings.warn( @@ -95,8 +95,29 @@ def new_decoder_forward( return outputs - -def apply_ulysses_attn_monkey_patch_llama(): +def get_sp_process_group(sequence_parallel_size=None): + if sequence_parallel_size is None: + return None + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + print(f"sequence_parallel_size is {sequence_parallel_size}, world_size is {world_size}") + if sequence_parallel_size is None or sequence_parallel_size == -1: + sequence_parallel_size = world_size + else: + assert world_size % sequence_parallel_size == 0 + num_sequence_parallel_groups: int = world_size // sequence_parallel_size + rank = torch.distributed.get_rank() + + for i in range(num_sequence_parallel_groups): + ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size) + if rank in ranks: + group = torch.distributed.new_group(ranks) + return group + +def apply_ulysses_attn_monkey_patch_llama(sp_size=None): + sp_group = get_sp_process_group(sp_size) + global ulysses_attn + ulysses_attn = UlyssesAttention(sequence_process_group=sp_group) transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = ( new_flash_attn_forward ) diff --git a/src/llamafactory/easy_context/zigzag_ring_attn/monkey_patch.py b/src/llamafactory/easy_context/zigzag_ring_attn/monkey_patch.py index 4ebcdb7d05..596d1d241d 100644 --- a/src/llamafactory/easy_context/zigzag_ring_attn/monkey_patch.py +++ b/src/llamafactory/easy_context/zigzag_ring_attn/monkey_patch.py @@ -4,7 +4,7 @@ import torch import torch.utils.checkpoint from ring_flash_attn.zigzag_ring_flash_attn import zigzag_ring_flash_attn_func - +from functools import partial def new_flash_attn_forward( self, @@ -16,6 +16,7 @@ def new_flash_attn_forward( dropout=0.0, softmax_scale=None, use_sliding_windows=False, + group=None ): if not self._flash_attn_uses_top_left_mask: causal = self.is_causal @@ -33,6 +34,7 @@ def new_flash_attn_forward( dropout, softmax_scale, causal=causal, + group=group ) return attn_output @@ -49,12 +51,12 @@ def new_decoder_forward( cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - assert isinstance( - self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2 - ) or isinstance( - self.self_attn, - transformers.models.mistral.modeling_mistral.MistralFlashAttention2, - ), "Please toggle on the Flash Attention 2 implementation when using zigzag ring attention monkey patch." + # assert isinstance( + # self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2 + # ) or isinstance( + # self.self_attn, + # transformers.models.mistral.modeling_mistral.MistralFlashAttention2, + # ), "Please toggle on the Flash Attention 2 implementation when using zigzag ring attention monkey patch." 
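The get_sp_process_group helper added above (and repeated in the zigzag patch below) decides how the world splits into sequence-parallel groups; whatever factor of the world size is left over becomes ordinary data parallelism. A pure-Python sketch of the partition (hypothetical world sizes, no torch.distributed needed):

# Illustrative sketch, not part of the patch: which ranks share a sequence-parallel
# group for a given world size and sp_size, mirroring get_sp_process_group above.
def sp_groups(world_size, sequence_parallel_size):
    if sequence_parallel_size is None or sequence_parallel_size == -1:
        sequence_parallel_size = world_size
    assert world_size % sequence_parallel_size == 0
    num_groups = world_size // sequence_parallel_size
    return [list(range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size))
            for i in range(num_groups)]

print(sp_groups(16, 8))   # two 8-way SP groups (ranks 0-7 and 8-15) -> 2-way data parallel
print(sp_groups(8, -1))   # one group of ranks 0-7: the whole job is sequence parallel

Each rank keeps a handle only to its own group, which is then passed to UlyssesAttention or zigzag_ring_flash_attn_func as the communication group.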
if "padding_mask" in kwargs: warnings.warn( @@ -95,18 +97,39 @@ def new_decoder_forward( return outputs -def apply_zigzag_ring_attn_monkey_patch_llama(): +def get_sp_process_group(sequence_parallel_size=None): + if sequence_parallel_size is None: + return None + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + print(f"sequence_parallel_size is {sequence_parallel_size}, world_size is {world_size}") + if sequence_parallel_size is None or sequence_parallel_size == -1: + sequence_parallel_size = world_size + else: + assert world_size % sequence_parallel_size == 0 + num_sequence_parallel_groups: int = world_size // sequence_parallel_size + rank = torch.distributed.get_rank() + + for i in range(num_sequence_parallel_groups): + ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size) + if rank in ranks: + group = torch.distributed.new_group(ranks) + return group + +def apply_zigzag_ring_attn_monkey_patch_llama(sp_size=None): + sp_group = get_sp_process_group(sp_size) transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = ( - new_flash_attn_forward + partial(new_flash_attn_forward, group=sp_group) ) transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = ( new_decoder_forward ) -def apply_zigzag_ring_attn_monkey_patch_mistral(): +def apply_zigzag_ring_attn_monkey_patch_mistral(sp_size=None): + sp_group = get_sp_process_group(sp_size) transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward = ( - new_flash_attn_forward + partial(new_flash_attn_forward, group=sp_group) ) transformers.models.mistral.modeling_mistral.MistralDecoderLayer.forward = ( new_decoder_forward diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py index b3420f1cc3..6aeda45187 100644 --- a/src/llamafactory/train/sft/workflow.py +++ b/src/llamafactory/train/sft/workflow.py @@ -17,7 +17,6 @@ import torch import os from ...easy_context import apply_seq_parallel_monkey_patch -from ...easy_context.dist_flash_attn.offload_buffer import offload_buffer if TYPE_CHECKING: @@ -91,10 +90,6 @@ def run_sft( trainer.save_state() if trainer.is_world_process_zero() and finetuning_args.plot_loss: plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) - - global offload_buffer - if offload_buffer is not None: - offload_buffer = None # # Evaluation # if training_args.do_eval: From 6b964cd872a999baaf5a5e5b0103e4f3e27e1168 Mon Sep 17 00:00:00 2001 From: qianhao Date: Mon, 28 Oct 2024 07:59:26 +0000 Subject: [PATCH 26/31] fix bug --- .../zigzag_ring_attn/monkey_patch.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/llamafactory/easy_context/zigzag_ring_attn/monkey_patch.py b/src/llamafactory/easy_context/zigzag_ring_attn/monkey_patch.py index 596d1d241d..87c1d2c434 100644 --- a/src/llamafactory/easy_context/zigzag_ring_attn/monkey_patch.py +++ b/src/llamafactory/easy_context/zigzag_ring_attn/monkey_patch.py @@ -4,7 +4,7 @@ import torch import torch.utils.checkpoint from ring_flash_attn.zigzag_ring_flash_attn import zigzag_ring_flash_attn_func -from functools import partial +from functools import partialmethod def new_flash_attn_forward( self, @@ -51,12 +51,12 @@ def new_decoder_forward( cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - # assert isinstance( - # self.self_attn, 
transformers.models.llama.modeling_llama.LlamaFlashAttention2 - # ) or isinstance( - # self.self_attn, - # transformers.models.mistral.modeling_mistral.MistralFlashAttention2, - # ), "Please toggle on the Flash Attention 2 implementation when using zigzag ring attention monkey patch." + assert isinstance( + self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2 + ) or isinstance( + self.self_attn, + transformers.models.mistral.modeling_mistral.MistralFlashAttention2, + ), "Please toggle on the Flash Attention 2 implementation when using zigzag ring attention monkey patch." if "padding_mask" in kwargs: warnings.warn( @@ -119,7 +119,7 @@ def get_sp_process_group(sequence_parallel_size=None): def apply_zigzag_ring_attn_monkey_patch_llama(sp_size=None): sp_group = get_sp_process_group(sp_size) transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = ( - partial(new_flash_attn_forward, group=sp_group) + partialmethod(new_flash_attn_forward, group=sp_group) ) transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = ( new_decoder_forward @@ -129,7 +129,7 @@ def apply_zigzag_ring_attn_monkey_patch_llama(sp_size=None): def apply_zigzag_ring_attn_monkey_patch_mistral(sp_size=None): sp_group = get_sp_process_group(sp_size) transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward = ( - partial(new_flash_attn_forward, group=sp_group) + partialmethod(new_flash_attn_forward, group=sp_group) ) transformers.models.mistral.modeling_mistral.MistralDecoderLayer.forward = ( new_decoder_forward From 19236bb5f57d027172361cb91586c44be8ff1e8e Mon Sep 17 00:00:00 2001 From: qianhao Date: Mon, 28 Oct 2024 08:15:18 +0000 Subject: [PATCH 27/31] update example --- tianqing_examples/Llama3-70B.sh | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tianqing_examples/Llama3-70B.sh b/tianqing_examples/Llama3-70B.sh index f88008de0e..a3d9e409c1 100644 --- a/tianqing_examples/Llama3-70B.sh +++ b/tianqing_examples/Llama3-70B.sh @@ -1,12 +1,15 @@ # You can observe that the number of steps for different stage is quite different. They are not magic number. They are set to those numbers simply because I esitimate the time it takes to finish the training, and # choose the number such that it fits my daily schedule>_<. This is for you to exactly reproduce my results. You many change the steps to other numbers if you want to. 
-MODEL_DIR=${MODEL_DIR:-"/root/model/Meta-Llama-3-70B"} +MODEL_DIR=${MODEL_DIR:-"/mnt/zj-gpfs/home/qianhao/models/Meta-Llama-3-70B/"} NGPUS=${NGPUS:-8} WORLD_SIZE=${WORLD_SIZE:-1} NUM_PROCESSES=$[${NGPUS}*$[WORLD_SIZE]] SEQ_LEN=${SEQ_LEN:-32768} SP_SIZE=${SP_SIZE:-1} BATCH_SIZE=${BATCH_SIZE:-1} +PARALLEL_MODE=${PARALLEL_MODE:-"dist_flash_attn"} +DATASET=${DATASET:-"long_sft_32k"} +ACC=${ACC:-4} export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' export WANDB_DISABLED=true export NCCL_DEBUG=WARN @@ -30,10 +33,10 @@ src/train.py \ --stage sft \ --do_train \ --finetuning_type full \ ---parallel_mode dist_flash_attn \ +--parallel_mode ${PARALLEL_MODE} \ --sp_size ${SP_SIZE} \ --deepspeed examples/deepspeed/ds_z3_offload_config.json \ ---dataset long_sft_32k \ +--dataset ${DATASET} \ --template llama3 \ --cutoff_len ${SEQ_LEN} \ --max_steps 10 \ @@ -45,7 +48,7 @@ src/train.py \ --plot_loss \ --overwrite_output_dir \ --per_device_train_batch_size ${BATCH_SIZE} \ ---gradient_accumulation_steps 4 \ +--gradient_accumulation_steps ${ACC} \ --learning_rate 2e-5 \ --num_train_epochs 1.0 \ --lr_scheduler_type cosine \ From a2159210af65e8a4096a36ee8a335a8d95f94b48 Mon Sep 17 00:00:00 2001 From: qianhao0713 <475483052@qq.com> Date: Fri, 6 Dec 2024 14:46:43 +0800 Subject: [PATCH 28/31] update transformers to 4.46.3 --- .../zigzag_ring_attn/monkey_patch.py | 139 ++++++++++++++++-- .../zigzag_ring_attn/prepare_inputs.py | 2 +- src/llamafactory/train/sft/trainer.py | 70 ++++++--- 3 files changed, 175 insertions(+), 36 deletions(-) diff --git a/src/llamafactory/easy_context/zigzag_ring_attn/monkey_patch.py b/src/llamafactory/easy_context/zigzag_ring_attn/monkey_patch.py index 87c1d2c434..d27bcf0de1 100644 --- a/src/llamafactory/easy_context/zigzag_ring_attn/monkey_patch.py +++ b/src/llamafactory/easy_context/zigzag_ring_attn/monkey_patch.py @@ -4,7 +4,8 @@ import torch import torch.utils.checkpoint from ring_flash_attn.zigzag_ring_flash_attn import zigzag_ring_flash_attn_func -from functools import partialmethod +from functools import partialmethod, partial +import inspect def new_flash_attn_forward( self, @@ -39,7 +40,42 @@ def new_flash_attn_forward( return attn_output +def new_flash_attn_forward_v2( + query_states, + key_states, + value_states, + attention_mask, + query_length, + is_causal, + dropout=0.0, + position_ids=None, + softmax_scale=None, + sliding_window=None, + use_top_left_mask=False, + softcap=None, + group=None +): + if not use_top_left_mask: + causal = is_causal + else: + causal = is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + assert attention_mask is None + assert causal is True + assert sliding_window is None + attn_output = zigzag_ring_flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale, + causal=causal, + group=group + ) + return attn_output + def new_decoder_forward( self, hidden_states: torch.Tensor, @@ -96,6 +132,63 @@ def new_decoder_forward( return outputs +def new_decoder_forward_v2( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, +) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + assert 
isinstance( + self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2 + ) or isinstance( + self.self_attn, + transformers.models.mistral.modeling_mistral.MistralFlashAttention2, + ), "Please toggle on the Flash Attention 2 implementation when using zigzag ring attention monkey patch." + + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs def get_sp_process_group(sequence_parallel_size=None): if sequence_parallel_size is None: @@ -118,19 +211,39 @@ def get_sp_process_group(sequence_parallel_size=None): def apply_zigzag_ring_attn_monkey_patch_llama(sp_size=None): sp_group = get_sp_process_group(sp_size) - transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = ( - partialmethod(new_flash_attn_forward, group=sp_group) - ) - transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = ( - new_decoder_forward - ) + if hasattr(transformers.models.llama.modeling_llama.LlamaFlashAttention2, '_flash_attention_forward'): + transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = ( + partialmethod(new_flash_attn_forward, group=sp_group) + ) + else: + transformers.models.llama.modeling_llama._flash_attention_forward = ( + partial(new_flash_attn_forward_v2, group=sp_group) + ) + if "position_embeddings" in inspect.getfullargspec(transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward).args: + transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = ( + new_decoder_forward_v2 + ) + else: + transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = ( + new_decoder_forward + ) def apply_zigzag_ring_attn_monkey_patch_mistral(sp_size=None): sp_group = get_sp_process_group(sp_size) - transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward = ( - partialmethod(new_flash_attn_forward, group=sp_group) - ) - transformers.models.mistral.modeling_mistral.MistralDecoderLayer.forward = ( - new_decoder_forward - ) + if hasattr(transformers.models.llama.modeling_llama.LlamaFlashAttention2, '_flash_attention_forward'): + transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = ( + partialmethod(new_flash_attn_forward, group=sp_group) + ) + else: + transformers.models.llama.modeling_llama._flash_attention_forward = ( + partial(new_flash_attn_forward_v2, group=sp_group) + ) + if "position_embeddings" in inspect.getfullargspec(transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward).args: + transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = 
( + new_decoder_forward_v2 + ) + else: + transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = ( + new_decoder_forward + ) diff --git a/src/llamafactory/easy_context/zigzag_ring_attn/prepare_inputs.py b/src/llamafactory/easy_context/zigzag_ring_attn/prepare_inputs.py index 24f7e4d467..6d2925aa41 100644 --- a/src/llamafactory/easy_context/zigzag_ring_attn/prepare_inputs.py +++ b/src/llamafactory/easy_context/zigzag_ring_attn/prepare_inputs.py @@ -70,7 +70,7 @@ def prepare_zigzag_ring_attn_sft_inputs( ) return { "input_ids": local_input_ids, - "attention_mask": local_attention_mask, + "attention_mask": None, "position_ids": local_position_ids, "labels": local_labels, } \ No newline at end of file diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index bc4a9c5a17..7a379fafdc 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -137,17 +137,22 @@ def save_predictions(self, predict_results: "PredictionOutput") -> None: class CustomSeqParallelTrainer(CustomSeq2SeqTrainer): - def compute_loss(self, model, inputs, return_outputs=False): + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): """ How the loss is computed by Trainer. By default, all models return the loss in the first element. Subclass and override for custom behavior. """ from transformers.trainer import _is_peft_model, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES - if self.label_smoother is not None and "labels" in inputs: + if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs: labels = inputs.pop("labels") else: labels = None + if self.model_accepts_loss_kwargs: + loss_kwargs = {} + if num_items_in_batch is not None: + loss_kwargs["num_items_in_batch"] = num_items_in_batch + inputs = {**inputs, **loss_kwargs} outputs = model(**inputs) # Save past state if it exists # TODO: this needs to be fixed and made cleaner later. @@ -160,34 +165,55 @@ def compute_loss(self, model, inputs, return_outputs=False): model_name = unwrapped_model.base_model.model._get_name() else: model_name = unwrapped_model._get_name() - if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): + # User-defined compute_loss function + if self.compute_loss_func is not None: + loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch) + elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): loss = self.label_smoother(outputs, labels, shift_labels=True) else: loss = self.label_smoother(outputs, labels) else: + if isinstance(outputs, dict) and "loss" not in outputs: + raise ValueError( + "The model did not return a loss from the inputs, only the following keys: " + f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}." + ) # We don't use .loss here since the model may return tuples instead of ModelOutput. 
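# Note on the zigzag_ring_attn input handling above: new_flash_attn_forward_v2 asserts
# attention_mask is None, and prepare_zigzag_ring_attn_sft_inputs now returns None for
# it, i.e. the ring kernel assumes unpadded, purely causal shards. As I understand the
# EasyContext-style sharding, each sequence-parallel rank holds two of 2*world_size
# sequence chunks so causal work stays balanced; the helper below is an illustrative
# sketch under that assumption, not the repo's extract_local implementation.
import torch

def zigzag_shard(x: torch.Tensor, rank: int, world_size: int, dim: int = 1) -> torch.Tensor:
    chunks = x.chunk(2 * world_size, dim=dim)
    # rank i keeps chunk i (early tokens) and chunk 2*world_size-1-i (late tokens)
    return torch.cat((chunks[rank], chunks[2 * world_size - 1 - rank]), dim=dim)

ids = torch.arange(8).unsqueeze(0)                  # toy sequence of 8 positions
assert zigzag_shard(ids, 0, 2).tolist() == [[0, 1, 6, 7]]
assert zigzag_shard(ids, 1, 2).tolist() == [[2, 3, 4, 5]]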
if self.finetuning_args.parallel_mode== "data_parallel": loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] else: - sp_size = self.finetuning_args.sp_size - loss_fn = CrossEntropyLoss(reduction='sum') - labels = inputs.pop("labels") - logits = outputs["logits"] if isinstance(outputs, dict) else outputs[1] - valid_label_cnt = (labels!=-100).sum(1)[None, :] - valid_label_cnt_gather = self.accelerator.gather(valid_label_cnt) - n_gpus = valid_label_cnt_gather.shape[0] - if sp_size == -1: - sp_size = n_gpus - dp_rank = self.accelerator.process_index // sp_size - valid_label_cnt_all =valid_label_cnt_gather[dp_rank * sp_size : (dp_rank+1) * sp_size].sum(0).detach() - shift_logits = logits.contiguous() - shift_labels = labels.contiguous() - bs = len(shift_labels) - loss = torch.zeros(bs, dtype=shift_logits.dtype, device=shift_labels.device) - for b in range(bs): - normalizer=valid_label_cnt_all[b].item() - loss[b]=loss_fn(shift_logits[b], shift_labels[b])/normalizer - loss = loss.mean()*sp_size + if num_items_in_batch is None or (num_items_in_batch is not None and not self.args.average_tokens_across_devices): + sp_size = self.finetuning_args.sp_size + loss_fn = CrossEntropyLoss(reduction='sum') + labels = inputs.pop("labels") + logits = outputs["logits"] if isinstance(outputs, dict) else outputs[1] + valid_label_cnt = (labels!=-100).sum(1)[None, :] + valid_label_cnt_gather = self.accelerator.gather(valid_label_cnt) + n_gpus = valid_label_cnt_gather.shape[0] + if sp_size == -1: + sp_size = n_gpus + dp_rank = self.accelerator.process_index // sp_size + valid_label_cnt_all =valid_label_cnt_gather[dp_rank * sp_size : (dp_rank+1) * sp_size].sum(0).detach() + shift_logits = logits.contiguous() + shift_labels = labels.contiguous() + bs = len(shift_labels) + loss = torch.zeros(bs, dtype=shift_logits.dtype, device=shift_labels.device) + for b in range(bs): + normalizer=valid_label_cnt_all[b].item() + loss[b]=loss_fn(shift_logits[b], shift_labels[b])/normalizer + loss = loss.mean()*sp_size + else: + loss_fn = CrossEntropyLoss(reduction='sum') + labels = inputs.pop("labels") + logits = outputs["logits"] if isinstance(outputs, dict) else outputs[1] + shift_logits = logits.contiguous() + shift_labels = labels.contiguous() + shift_logits = shift_logits.view(-1, self.model.config.vocab_size) + shift_labels = shift_labels.view(-1) + loss = loss_fn(shift_logits, shift_labels)/num_items_in_batch + + if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs: + loss *= self.accelerator.num_processes return (loss, outputs) if return_outputs else loss From 358cbad24908c7c6e83dd932db2c8931300f3af9 Mon Sep 17 00:00:00 2001 From: qianhao Date: Mon, 9 Dec 2024 03:28:44 +0000 Subject: [PATCH 29/31] fix bug --- .../easy_context/dist_flash_attn/monkey_patch.py | 6 +++++- src/llamafactory/hparams/finetuning_args.py | 4 ++++ src/llamafactory/train/sft/trainer.py | 4 +++- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py b/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py index 59dca890d7..509db1e5e2 100644 --- a/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py +++ b/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py @@ -664,6 +664,8 @@ def llama_model_forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **loss_kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: 
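# Note on the num_logits_to_keep parameter added to the signature above (matching the
# newer transformers LlamaForCausalLM.forward that the 4.46.3 commit targets): the hunk
# below changes the lm_head call to hidden_states[:, -num_logits_to_keep:, :], and the
# default of 0 keeps every position because -0 == 0, while generation can pass 1 to
# compute only the last-token logits. Quick illustration of that slicing convention:
import torch
h = torch.randn(2, 5, 16)                      # (batch, seq, hidden) toy tensor
assert h[:, -0:, :].shape == (2, 5, 16)        # 0 -> all positions (training / SFT)
assert h[:, -1:, :].shape == (2, 1, 16)        # 1 -> last position only (decoding)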
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -691,10 +693,12 @@ def llama_model_forward( logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] logits = torch.cat(logits, dim=-1) else: - logits = self.lm_head(hidden_states) + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) logits = logits.float() loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index 5f9889c82b..c525502694 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -330,6 +330,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA default=0.0, metadata={"help": "0 for remain all activation memory in gpu, 1 for offload all activation memory in cpu"} ) + per_instance_loss: bool = field( + default=False, + metadata={"help": "if update transformers to 4.46.3, the loss will be calculated in a global batch by default, enable per_instance_loss will calculate loss in each instance"} + ) def __post_init__(self): def split_arg(arg): diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index 7a379fafdc..dbd8420a26 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -181,8 +181,10 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N # We don't use .loss here since the model may return tuples instead of ModelOutput. 
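# Note on the sequence-parallel loss branch shown in the hunks above: every sp rank only
# sees a shard of each sample, so it takes a sum-reduction CE over its shard, divides by
# the sample's *total* valid-label count (gathered across the sp group), and multiplies
# by sp_size. That trailing *sp_size cancels the gradient averaging later performed over
# those sp_size ranks, recovering the per-sample token mean. Toy check with made-up
# numbers (not from any run):
import torch
sp_size, valid_tokens = 4, 12
per_token_ce = torch.rand(valid_tokens)                   # CE of each supervised token
shards = per_token_ce.chunk(sp_size)                      # what each sp rank computes on
per_rank_loss = torch.stack([s.sum() / valid_tokens * sp_size for s in shards])
assert torch.allclose(per_rank_loss.mean(), per_token_ce.mean())   # mean over ranks == token mean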
if self.finetuning_args.parallel_mode== "data_parallel": loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] + if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs: + loss *= self.accelerator.num_processes else: - if num_items_in_batch is None or (num_items_in_batch is not None and not self.args.average_tokens_across_devices): + if num_items_in_batch is None or self.finetuning_args.per_instance_loss: sp_size = self.finetuning_args.sp_size loss_fn = CrossEntropyLoss(reduction='sum') labels = inputs.pop("labels") From 1eead845892a36bea93bf86b654c6f15b46ef610 Mon Sep 17 00:00:00 2001 From: qianhao Date: Wed, 11 Dec 2024 01:39:44 +0000 Subject: [PATCH 30/31] fix bug when parallel_mode=data_parallel --- .../dist_flash_attn/monkey_patch.py | 2 +- src/llamafactory/train/sft/trainer.py | 98 +++++++++++++++++-- 2 files changed, 91 insertions(+), 9 deletions(-) diff --git a/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py b/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py index 509db1e5e2..6903c5f2bc 100644 --- a/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py +++ b/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py @@ -697,7 +697,7 @@ def llama_model_forward( logits = logits.float() loss = None - if labels is not None: + if labels is not None and hasattr(self, 'loss_function'): loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs) if not return_dict: output = (logits,) + outputs[1:] diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index dbd8420a26..d00acdf76d 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -11,7 +11,7 @@ from ...extras.logging import get_logger from ..trainer_utils import create_custom_optimzer, create_custom_scheduler from torch.utils.data import DataLoader -from transformers.utils import is_datasets_available +from transformers.utils import is_datasets_available, is_sagemaker_mp_enabled from transformers.trainer_utils import seed_worker import datasets from torch.nn import CrossEntropyLoss @@ -137,6 +137,77 @@ def save_predictions(self, predict_results: "PredictionOutput") -> None: class CustomSeqParallelTrainer(CustomSeq2SeqTrainer): + def training_step( + self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None + ) -> torch.Tensor: + """ + Perform a training step on a batch of inputs. + + Subclass and override to inject custom behavior. + + Args: + model (`nn.Module`): + The model to train. + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + + Return: + `torch.Tensor`: The tensor with training loss on this batch. 
+ """ + model.train() + if hasattr(self.optimizer, "train") and callable(self.optimizer.train): + self.optimizer.train() + + use_per_instance_loss = self.finetuning_args.per_instance_loss + + inputs = self._prepare_inputs(inputs) + if is_sagemaker_mp_enabled(): + loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps) + return loss_mb.reduce_mean().detach().to(self.args.device) + + with self.compute_loss_context_manager(): + loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch) + + del inputs + if (hasattr(self.args, "torch_empty_cache_steps") + and self.args.torch_empty_cache_steps is not None + and self.state.global_step % self.args.torch_empty_cache_steps == 0 + ): + if is_torch_xpu_available(): + torch.xpu.empty_cache() + elif is_torch_mlu_available(): + torch.mlu.empty_cache() + elif is_torch_musa_available(): + torch.musa.empty_cache() + elif is_torch_npu_available(): + torch.npu.empty_cache() + elif is_torch_mps_available(min_version="2.0"): + torch.mps.empty_cache() + else: + torch.cuda.empty_cache() + + kwargs = {} + + # For LOMO optimizers you need to explicitly use the learnign rate + # if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: + # kwargs["learning_rate"] = self._get_learning_rate() + + if self.args.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu parallel training + + if self.use_apex: + with amp.scale_loss(loss, self.optimizer) as scaled_loss: + scaled_loss.backward() + else: + self.accelerator.backward(loss, **kwargs) + # Finally we need to normalize the loss for reporting + if num_items_in_batch is None or self.finetuning_args.per_instance_loss: + return loss.detach() / self.args.gradient_accumulation_steps + return loss.detach() + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): """ How the loss is computed by Trainer. By default, all models return the loss in the first element. @@ -144,11 +215,13 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N Subclass and override for custom behavior. """ from transformers.trainer import _is_peft_model, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES + if not hasattr(self, 'compute_loss_func'): + self.compute_loss_func = None if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs: labels = inputs.pop("labels") else: labels = None - if self.model_accepts_loss_kwargs: + if hasattr(self, 'model_accepts_loss_kwargs') and self.model_accepts_loss_kwargs: loss_kwargs = {} if num_items_in_batch is not None: loss_kwargs["num_items_in_batch"] = num_items_in_batch @@ -173,15 +246,23 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N else: loss = self.label_smoother(outputs, labels) else: - if isinstance(outputs, dict) and "loss" not in outputs: - raise ValueError( - "The model did not return a loss from the inputs, only the following keys: " - f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}." - ) # We don't use .loss here since the model may return tuples instead of ModelOutput. if self.finetuning_args.parallel_mode== "data_parallel": + if isinstance(outputs, dict) and "loss" not in outputs: + raise ValueError( + "The model did not return a loss from the inputs, only the following keys: " + f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}." 
+ ) loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] - if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs: + if not hasattr(self.args, 'average_tokens_across_devices'): + self.args.average_tokens_across_devices = None + if not hasattr(self, 'model_accepts_loss_kwargs'): + self.model_accepts_loss_kwargs= None + if self.finetuning_args.per_instance_loss and num_items_in_batch is not None: + labels = inputs.pop("labels") + valid_label_cnt = (labels!=-100).sum() + loss *= num_items_in_batch / valid_label_cnt + elif self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs: loss *= self.accelerator.num_processes else: if num_items_in_batch is None or self.finetuning_args.per_instance_loss: @@ -205,6 +286,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N loss[b]=loss_fn(shift_logits[b], shift_labels[b])/normalizer loss = loss.mean()*sp_size else: + assert self.args.average_tokens_across_devices is True, "must ensure average_tokens_across_devices if parallel_mode is not data_parallel" loss_fn = CrossEntropyLoss(reduction='sum') labels = inputs.pop("labels") logits = outputs["logits"] if isinstance(outputs, dict) else outputs[1] From e6f1947dad9cbafd38b15862a4004255286a46ce Mon Sep 17 00:00:00 2001 From: qianhao0713 <475483052@qq.com> Date: Wed, 11 Dec 2024 11:17:19 +0800 Subject: [PATCH 31/31] remove per_instance_loss --- src/llamafactory/easy_context/__init__.py | 3 - .../zigzag_ring_attn/monkey_patch.py | 20 ----- src/llamafactory/hparams/finetuning_args.py | 4 - src/llamafactory/train/sft/trainer.py | 81 +------------------ 4 files changed, 3 insertions(+), 105 deletions(-) diff --git a/src/llamafactory/easy_context/__init__.py b/src/llamafactory/easy_context/__init__.py index bffbbba570..a1d1b02a79 100644 --- a/src/llamafactory/easy_context/__init__.py +++ b/src/llamafactory/easy_context/__init__.py @@ -2,7 +2,6 @@ from .dist_flash_attn.monkey_patch import apply_dist_flash_attn_monkey_patch_llama from .zigzag_ring_attn.prepare_inputs import prepare_zigzag_ring_attn_inputs, prepare_zigzag_ring_attn_sft_inputs from .zigzag_ring_attn.monkey_patch import apply_zigzag_ring_attn_monkey_patch_llama -from .zigzag_ring_attn.monkey_patch import apply_zigzag_ring_attn_monkey_patch_mistral from .unsloth_offloaded_gradient_checkpoint.monkey_patch import apply_unsloth_offloaded_gradient_checkpoint_monkey_patch from .ulysses_attn.prepare_inputs import prepare_ulysses_attn_inputs, prepare_ulysses_attn_sft_inputs from .ulysses_attn.monkey_patch import apply_ulysses_attn_monkey_patch_llama @@ -72,8 +71,6 @@ def apply_seq_parallel_monkey_patch( return elif seq_algo == "zigzag_ring_attn" and model == "llama": apply_zigzag_ring_attn_monkey_patch_llama(sp_size=sp_size) - elif seq_algo == "zigzag_ring_attn" and model == "mistral": - apply_zigzag_ring_attn_monkey_patch_mistral(sp_size=sp_size) elif seq_algo == "dist_flash_attn" and model == "llama": apply_dist_flash_attn_monkey_patch_llama(sp_size=sp_size, enable_offload=enable_offload, offload_percent=offload_percent) elif seq_algo == "ulysses_attn" and model == "llama": diff --git a/src/llamafactory/easy_context/zigzag_ring_attn/monkey_patch.py b/src/llamafactory/easy_context/zigzag_ring_attn/monkey_patch.py index d27bcf0de1..250c1ea6f1 100644 --- a/src/llamafactory/easy_context/zigzag_ring_attn/monkey_patch.py +++ b/src/llamafactory/easy_context/zigzag_ring_attn/monkey_patch.py @@ -227,23 +227,3 @@ def apply_zigzag_ring_attn_monkey_patch_llama(sp_size=None): 
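# Note on the per_instance_loss rescaling above: with transformers >= 4.46 and loss
# kwargs enabled, the model already returns sum(CE) / num_items_in_batch, i.e. a mean
# over the whole accumulated global batch. Multiplying by
# num_items_in_batch / valid_label_cnt cancels the global count and leaves the
# per-instance token mean (ignoring the one-token causal shift). Toy numbers for
# illustration only:
import torch
ce = torch.rand(7)                       # per-token CE of one instance (7 valid labels)
num_items_in_batch = 100                 # valid labels across the accumulated global batch
model_loss = ce.sum() / num_items_in_batch
assert torch.allclose(model_loss * num_items_in_batch / ce.numel(), ce.mean())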
transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = ( new_decoder_forward ) - - -def apply_zigzag_ring_attn_monkey_patch_mistral(sp_size=None): - sp_group = get_sp_process_group(sp_size) - if hasattr(transformers.models.llama.modeling_llama.LlamaFlashAttention2, '_flash_attention_forward'): - transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = ( - partialmethod(new_flash_attn_forward, group=sp_group) - ) - else: - transformers.models.llama.modeling_llama._flash_attention_forward = ( - partial(new_flash_attn_forward_v2, group=sp_group) - ) - if "position_embeddings" in inspect.getfullargspec(transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward).args: - transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = ( - new_decoder_forward_v2 - ) - else: - transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = ( - new_decoder_forward - ) diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index c525502694..5f9889c82b 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -330,10 +330,6 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA default=0.0, metadata={"help": "0 for remain all activation memory in gpu, 1 for offload all activation memory in cpu"} ) - per_instance_loss: bool = field( - default=False, - metadata={"help": "if update transformers to 4.46.3, the loss will be calculated in a global batch by default, enable per_instance_loss will calculate loss in each instance"} - ) def __post_init__(self): def split_arg(arg): diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index d00acdf76d..ca30bcc7a2 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -11,7 +11,7 @@ from ...extras.logging import get_logger from ..trainer_utils import create_custom_optimzer, create_custom_scheduler from torch.utils.data import DataLoader -from transformers.utils import is_datasets_available, is_sagemaker_mp_enabled +from transformers.utils import is_datasets_available from transformers.trainer_utils import seed_worker import datasets from torch.nn import CrossEntropyLoss @@ -137,77 +137,6 @@ def save_predictions(self, predict_results: "PredictionOutput") -> None: class CustomSeqParallelTrainer(CustomSeq2SeqTrainer): - def training_step( - self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None - ) -> torch.Tensor: - """ - Perform a training step on a batch of inputs. - - Subclass and override to inject custom behavior. - - Args: - model (`nn.Module`): - The model to train. - inputs (`Dict[str, Union[torch.Tensor, Any]]`): - The inputs and targets of the model. - - The dictionary will be unpacked before being fed to the model. Most models expect the targets under the - argument `labels`. Check your model's documentation for all accepted arguments. - - Return: - `torch.Tensor`: The tensor with training loss on this batch. 
- """ - model.train() - if hasattr(self.optimizer, "train") and callable(self.optimizer.train): - self.optimizer.train() - - use_per_instance_loss = self.finetuning_args.per_instance_loss - - inputs = self._prepare_inputs(inputs) - if is_sagemaker_mp_enabled(): - loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps) - return loss_mb.reduce_mean().detach().to(self.args.device) - - with self.compute_loss_context_manager(): - loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch) - - del inputs - if (hasattr(self.args, "torch_empty_cache_steps") - and self.args.torch_empty_cache_steps is not None - and self.state.global_step % self.args.torch_empty_cache_steps == 0 - ): - if is_torch_xpu_available(): - torch.xpu.empty_cache() - elif is_torch_mlu_available(): - torch.mlu.empty_cache() - elif is_torch_musa_available(): - torch.musa.empty_cache() - elif is_torch_npu_available(): - torch.npu.empty_cache() - elif is_torch_mps_available(min_version="2.0"): - torch.mps.empty_cache() - else: - torch.cuda.empty_cache() - - kwargs = {} - - # For LOMO optimizers you need to explicitly use the learnign rate - # if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: - # kwargs["learning_rate"] = self._get_learning_rate() - - if self.args.n_gpu > 1: - loss = loss.mean() # mean() to average on multi-gpu parallel training - - if self.use_apex: - with amp.scale_loss(loss, self.optimizer) as scaled_loss: - scaled_loss.backward() - else: - self.accelerator.backward(loss, **kwargs) - # Finally we need to normalize the loss for reporting - if num_items_in_batch is None or self.finetuning_args.per_instance_loss: - return loss.detach() / self.args.gradient_accumulation_steps - return loss.detach() - def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): """ How the loss is computed by Trainer. By default, all models return the loss in the first element. @@ -258,14 +187,10 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N self.args.average_tokens_across_devices = None if not hasattr(self, 'model_accepts_loss_kwargs'): self.model_accepts_loss_kwargs= None - if self.finetuning_args.per_instance_loss and num_items_in_batch is not None: - labels = inputs.pop("labels") - valid_label_cnt = (labels!=-100).sum() - loss *= num_items_in_batch / valid_label_cnt - elif self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs: + if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs: loss *= self.accelerator.num_processes else: - if num_items_in_batch is None or self.finetuning_args.per_instance_loss: + if num_items_in_batch is None: sp_size = self.finetuning_args.sp_size loss_fn = CrossEntropyLoss(reduction='sum') labels = inputs.pop("labels")