diff --git a/.gitignore b/.gitignore index 0355c66607..8dc3597ec5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,11 @@ +# local directories to ignore +/output/ +*.json +*llmtuner* +/wandb/ +/examples/ +/data/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/data/dataset_info.json b/data/dataset_info.json index 1d226b3adc..70261447e7 100644 --- a/data/dataset_info.json +++ b/data/dataset_info.json @@ -8,6 +8,12 @@ "alpaca_zh_demo": { "file_name": "alpaca_zh_demo.json" }, + "long_sft_32k": { + "file_name": "sample_long_sft_32k_48M.json" + }, + "long_sft_128k": { + "file_name": "sample_long_sft_128k.parquet" + }, "glaive_toolcall_en_demo": { "file_name": "glaive_toolcall_en_demo.json", "formatting": "sharegpt", @@ -551,4 +557,4 @@ }, "folder": "python" } -} \ No newline at end of file +} diff --git a/examples/accelerate/ds_multi_nodes.yaml b/examples/accelerate/ds_multi_nodes.yaml new file mode 100644 index 0000000000..0b465fae9a --- /dev/null +++ b/examples/accelerate/ds_multi_nodes.yaml @@ -0,0 +1,15 @@ +debug: false +deepspeed_config: + deepspeed_config_file: examples/deepspeed/ds_z3_offload_config.json + deepspeed_multinode_launcher: standard + zero3_init_flag: true +distributed_type: DEEPSPEED +num_processes: 16 +downcast_bf16: 'no' +main_training_function: main +rdzv_backend: c10d +same_network: false +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/deepspeed/ds_z3_offload_config.json b/examples/deepspeed/ds_z3_offload_config.json index 026aabbcda..b00b8bc72e 100644 --- a/examples/deepspeed/ds_z3_offload_config.json +++ b/examples/deepspeed/ds_z3_offload_config.json @@ -17,13 +17,8 @@ }, "zero_optimization": { "stage": 3, - "offload_optimizer": { - "device": "cpu", - "pin_memory": true - }, "offload_param": { - "device": "cpu", - "pin_memory": true + "device": "cpu" }, "overlap_comm": true, "contiguous_gradients": true, @@ -34,5 +29,6 @@ "stage3_max_live_parameters": 1e9, "stage3_max_reuse_distance": 1e9, "stage3_gather_16bit_weights_on_model_save": true - } -} \ No newline at end of file + }, + "steps_per_print":1 +} diff --git a/src/llamafactory/data/__init__.py b/src/llamafactory/data/__init__.py index b08691d38b..f2ca6b9e9d 100644 --- a/src/llamafactory/data/__init__.py +++ b/src/llamafactory/data/__init__.py @@ -1,4 +1,4 @@ -from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding +from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding, SeqParallelDataCollatorForLanguageModeling from .data_utils import Role, split_dataset from .loader import get_dataset from .template import TEMPLATES, Template, get_template_and_fix_tokenizer @@ -7,6 +7,7 @@ __all__ = [ "KTODataCollatorWithPadding", "PairwiseDataCollatorWithPadding", + "SeqParallelDataCollatorForLanguageModeling", "Role", "split_dataset", "get_dataset", diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 1dc8dd8d38..29023f25f9 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -2,8 +2,9 @@ from typing import Any, Dict, Sequence import torch -from transformers import DataCollatorForSeq2Seq - +from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling +from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union +from llamafactory.easy_context import prepare_seq_parallel_sft_inputs @dataclass class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq): @@ -79,3 +80,85 @@ def 
__call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor batch["kto_tags"] = torch.tensor(kto_tags) return batch + +@dataclass +class SeqParallelDataCollator(DataCollatorForSeq2Seq): + r""" + Data collator for sequence parallel in supervised finetune(sft) stage. + """ + seq_algo: str = "data_parallel", + sp_size: int = -1 + rank: int = 0 + world_size: int = 8 + device: Optional[Any] = None + + def __call__(self, features: Sequence[Dict[str, Any]], return_tensors=None) -> Dict[str, torch.Tensor]: + batch = super().__call__(features, return_tensors) + if self.seq_algo == "data_parallel": + return batch + input_ids = batch["input_ids"] + attention_mask = batch["attention_mask"] + labels = batch["labels"] + world_size = self.world_size + sp_rank = self.rank + if self.sp_size != -1: + dp_rank = self.rank // self.sp_size + sp_rank = self.rank % self.sp_size + world_size = self.sp_size + bs = len(input_ids) + dp_size = self.world_size // self.sp_size + group_bs = bs // dp_size + input_ids = input_ids[dp_rank * group_bs: (dp_rank + 1) * group_bs] + attention_mask = attention_mask[dp_rank * group_bs: (dp_rank + 1) * group_bs] + labels = labels[dp_rank * group_bs: (dp_rank + 1) * group_bs] + batch = prepare_seq_parallel_sft_inputs(self.seq_algo, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=None, + labels=labels, + rank=sp_rank, + world_size=world_size, + device=self.device) + return batch + + +@dataclass +class SeqParallelDataCollatorForLanguageModeling(DataCollatorForLanguageModeling): + r""" + Data collator for sequence parallel in pretrain(pt) stage. + Reuse the sequence parallel distributing function for sft stage. + """ + seq_algo: str = "data_parallel" + sp_size: int = -1 + rank: int = 0 + world_size: int = 8 + device: Optional[Any] = None + + def __call__(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: + batch = super().__call__(examples) + if self.seq_algo == "data_parallel": + return batch + input_ids = batch["input_ids"] + attention_mask = batch["attention_mask"] + labels = batch["labels"] + world_size = self.world_size + sp_rank = self.rank + if self.sp_size != -1: + dp_rank = self.rank // self.sp_size + sp_rank = self.rank % self.sp_size + world_size = self.sp_size + bs = len(input_ids) + dp_size = self.world_size // self.sp_size + group_bs = bs // dp_size + input_ids = input_ids[dp_rank * group_bs: (dp_rank + 1) * group_bs] + attention_mask = attention_mask[dp_rank * group_bs: (dp_rank + 1) * group_bs] + labels = labels[dp_rank * group_bs: (dp_rank + 1) * group_bs] + batch = prepare_seq_parallel_sft_inputs(self.seq_algo, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=None, + labels=labels, + rank=sp_rank, + world_size=world_size, + device=self.device) + return batch diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index ba426f8156..8b356e380e 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -143,6 +143,18 @@ def get_dataset( if has_tokenized_data(data_args.tokenized_path): logger.warning("Loading dataset from disk will ignore other data arguments.") dataset = load_from_disk(data_args.tokenized_path) + # ---lsy--- + to_remove = [col for col in dataset.column_names if col != "input_ids"] + # import copy + # first_item = copy.deepcopy(dataset[0]['input_ids']) + def update_column(example): + example['input_ids'] = example['input_ids'][:data_args.cutoff_len] + # example['input_ids'] = first_item[:data_args.cutoff_len] + 
return example + + # # 使用 map 方法添加新列 + dataset = dataset.map(update_column,remove_columns=to_remove) + # ---lsy--- logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path)) if data_args.streaming: dataset = dataset.to_iterable_dataset() @@ -166,6 +178,7 @@ def get_dataset( data_args, training_args, stage, template, tokenizer, processor ) column_names = list(next(iter(dataset)).keys()) + logger.debug(f"remove_columns:{column_names}") kwargs = {} if not data_args.streaming: kwargs = dict( @@ -175,9 +188,9 @@ def get_dataset( ) dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs) - + if data_args.tokenized_path is not None: - if training_args.should_save: + if training_args.should_save: dataset.save_to_disk(data_args.tokenized_path) logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path)) logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path)) diff --git a/src/llamafactory/easy_context/__init__.py b/src/llamafactory/easy_context/__init__.py new file mode 100644 index 0000000000..a1d1b02a79 --- /dev/null +++ b/src/llamafactory/easy_context/__init__.py @@ -0,0 +1,85 @@ +from .dist_flash_attn.prepare_input import prepare_dist_flash_attn_inputs, prepare_dist_flash_attn_sft_inputs +from .dist_flash_attn.monkey_patch import apply_dist_flash_attn_monkey_patch_llama +from .zigzag_ring_attn.prepare_inputs import prepare_zigzag_ring_attn_inputs, prepare_zigzag_ring_attn_sft_inputs +from .zigzag_ring_attn.monkey_patch import apply_zigzag_ring_attn_monkey_patch_llama +from .unsloth_offloaded_gradient_checkpoint.monkey_patch import apply_unsloth_offloaded_gradient_checkpoint_monkey_patch +from .ulysses_attn.prepare_inputs import prepare_ulysses_attn_inputs, prepare_ulysses_attn_sft_inputs +from .ulysses_attn.monkey_patch import apply_ulysses_attn_monkey_patch_llama +import torch +import torch.nn.functional as F + +def prepare_seq_parallel_inputs( + seq_algo, input_ids, position_ids, target_ids, rank, world_size, device +): + if seq_algo == "zigzag_ring_attn": + return prepare_zigzag_ring_attn_inputs( + input_ids, position_ids, target_ids, rank, world_size, device + ) + elif seq_algo == "dist_flash_attn": + return prepare_dist_flash_attn_inputs( + input_ids, position_ids, target_ids, rank, world_size, device + ) + elif seq_algo == "ulysses_attn": + return prepare_ulysses_attn_inputs( + input_ids, position_ids, target_ids, rank, world_size, device + ) + elif seq_algo == "data_parallel": + return { + "local_input_ids": input_ids.to(device), + "local_position_ids": position_ids.to(device), + "local_target_ids": target_ids.to(device), + } + else: + raise ValueError(f"Invalid seq_algo: {seq_algo}") + +def prepare_seq_parallel_sft_inputs( + seq_algo, input_ids, attention_mask, position_ids, labels, rank, world_size, device +): + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + shift_labels = F.pad(labels, [0, 1], 'constant', -100)[:, 1:] + if seq_algo == "zigzag_ring_attn": + return prepare_zigzag_ring_attn_sft_inputs( + input_ids, attention_mask, position_ids, shift_labels, rank, world_size, device + ) + elif seq_algo == "dist_flash_attn": + return prepare_dist_flash_attn_sft_inputs( + input_ids, attention_mask, position_ids, shift_labels, rank, world_size, device + ) + elif seq_algo == "ulysses_attn": + return 
prepare_ulysses_attn_sft_inputs( + input_ids, attention_mask, position_ids, shift_labels, rank, world_size, device + ) + elif seq_algo == "data_parallel": + return { + "input_ids": input_ids, + "position_ids": position_ids, + "attention_mask": attention_mask, + "target_ids": labels, + } + else: + raise ValueError(f"Invalid seq_algo: {seq_algo}") + +def apply_seq_parallel_monkey_patch( + seq_algo, model, sp_size=None, enable_offload=False, offload_percent=0. +): + assert seq_algo in ["zigzag_ring_attn", "dist_flash_attn", "ulysses_attn", "data_parallel"], f"Invalid seq_algo: {seq_algo}" + assert model in ["llama", "mistral"], f"Invalid model: {model}" + if seq_algo == "data_parallel": + return + elif seq_algo == "zigzag_ring_attn" and model == "llama": + apply_zigzag_ring_attn_monkey_patch_llama(sp_size=sp_size) + elif seq_algo == "dist_flash_attn" and model == "llama": + apply_dist_flash_attn_monkey_patch_llama(sp_size=sp_size, enable_offload=enable_offload, offload_percent=offload_percent) + elif seq_algo == "ulysses_attn" and model == "llama": + apply_ulysses_attn_monkey_patch_llama(sp_size=sp_size) + else: + raise ValueError(f"Invalid seq_algo: {seq_algo} or model: {model}") + +def prepare_dataloader(seq_algo, dataloader, acclerator): + if seq_algo == "data_parallel": + return acclerator.prepare(dataloader) + else: + return dataloader diff --git a/src/llamafactory/easy_context/dist_flash_attn/README.md b/src/llamafactory/easy_context/dist_flash_attn/README.md new file mode 100644 index 0000000000..2025265c3e --- /dev/null +++ b/src/llamafactory/easy_context/dist_flash_attn/README.md @@ -0,0 +1,11 @@ +# LightSeq +Taken from https://github.com/RulinShao/LightSeq. All credits to the authors. + +``` +@article{li2023lightseq, + title={LIGHTSEQ: SEQUENCE LEVEL PARALLELISM FOR DISTRIBUTED TRAINING OF LONG CONTEXT TRANS}, + author={Li, Dacheng and Shao, Rulin and Xie𝑠, Anze and Xing𝑐𝑚, Eric P and Gonzalez𝑏, Joseph E and Stoica𝑏, Ion and Ma𝑢, Xuezhe and Zhang𝑠, Hao}, + journal={arXiv preprint arXiv:2310.03294}, + year={2023} +} +``` \ No newline at end of file diff --git a/src/llamafactory/easy_context/dist_flash_attn/async_communication.py b/src/llamafactory/easy_context/dist_flash_attn/async_communication.py new file mode 100644 index 0000000000..68b35b5ae6 --- /dev/null +++ b/src/llamafactory/easy_context/dist_flash_attn/async_communication.py @@ -0,0 +1,531 @@ +import threading +import math +import os + +import torch +import torch.distributed as dist +from torch.distributed import batch_isend_irecv, P2POp, isend, irecv, get_process_group_ranks + +# Sequence parallel group that the current rank belongs to. +_SEQUENCE_PARALLEL_GROUP = None + +# These values enable us to change the sequence parallel sizes on the fly. 
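Note on the label handling in `prepare_seq_parallel_sft_inputs` above: the labels are shifted by one position (padded with -100) *before* the sequence is sharded across ranks, so each rank's shard already holds the target for its own tokens, including the token at the shard boundary. A minimal standalone sketch of that idea (the tensor values and the 2-way split are illustrative only, not the repo's API):

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([[10, 11, 12, 13, 14, 15, 16, 17]])

# Shift first: position i now stores the target for predicting the next token.
shift_labels = F.pad(labels, [0, 1], "constant", -100)[:, 1:]

# Then shard the sequence dimension across (here) 2 sequence-parallel ranks.
rank0, rank1 = shift_labels.chunk(2, dim=-1)
print(rank0)  # tensor([[11, 12, 13, 14]])   boundary target 14 stays on rank 0
print(rank1)  # tensor([[15, 16, 17, -100]])
# Shifting after sharding would instead drop target 14 from rank 0's shard.
```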
+_SEQUENCE_PARALLEL_SIZE = None +_SEQUENCE_PARALLEL_RANK = None + +# Global buffer for P2P +_PEER_Q = None +_PEER_K = None +_PEER_V = None +_PEER_M = None +_PEER_L = None +_PEER_O = None +_PEER_Q_BWD = None +_PEER_K_BWD = None +_PEER_V_BWD = None +_PEER_O_BWD = None + +_DELTA_DQ = None +_PEER_L = None +_DELTA_DK = None +_DELTA_DV = None +_DK_DELTA_FROM_PEER = None +_DV_DELTA_FROM_PEER = None +_PEER_DO = None + + +_fwd_send_volume = 0 +_fwd_recv_volume = 0 +_bwd_send_volume = 0 +_bwd_recv_volume = 0 + +def initialize_distributed(sp_size=None): + if dist.is_initialized(): + if dist.get_rank() == 0: + print( + "torch distributed is already initialized, " + "skipping initialization ...", + flush=True, + ) + else: + if int(os.environ["RANK"]) == 0: + print("Initializing Torch distributed.") + dist.init_process_group(backend="nccl") + local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) + global_world_size = dist.get_world_size() + torch.cuda.set_device(dist.get_rank() % local_world_size) + + _initialize_sequence_parallel(sp_size) + # create_nccl_communicators() + +def _initialize_sequence_parallel(sequence_parallel_size=None): + # Get world size and rank. Ensure some consistencies. + # assert sequence_parallel_size is None, "Multiple sequence parallel group not implemented." + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + print(f"sequence_parallel_size is {sequence_parallel_size}, world_size is {world_size}") + + if sequence_parallel_size is None or sequence_parallel_size == -1: + sequence_parallel_size = world_size + else: + assert world_size % sequence_parallel_size == 0 + num_sequence_parallel_groups: int = world_size // sequence_parallel_size + + rank = torch.distributed.get_rank() + + # Build the sequence parallel groups. + global _SEQUENCE_PARALLEL_GROUP + global _SEQUENCE_PARALLEL_RANK + global _SEQUENCE_PARALLEL_SIZE + + assert ( + _SEQUENCE_PARALLEL_GROUP is None + ), 'sequence parallel group is already initialized' + for i in range(num_sequence_parallel_groups): + ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size) + group = torch.distributed.new_group(ranks) + if rank in ranks: + _SEQUENCE_PARALLEL_GROUP = group + _SEQUENCE_PARALLEL_RANK = ranks.index(rank) + _SEQUENCE_PARALLEL_SIZE = len(ranks) + + if dist.get_rank() == 0: + print("************ Finish sequence pralell group Initialization. 
***********") + # _set_global_memory_buffer() + +def maybe_get_set_global_memory_buffer(q, k, v, m, l, o): + global _PEER_Q, _PEER_K, _PEER_V, _PEER_M, _PEER_L, _PEER_O + if _PEER_Q is None: + try: + if get_sequence_parallel_rank() == 0: + print("Initializing global memoery buffer.") + except: + print("Initializing global memoery buffer.") + _PEER_Q = [torch.empty_like(q) for _ in range(2)] + _PEER_K = [torch.empty_like(k) for _ in range(2)] + _PEER_V = [torch.empty_like(v) for _ in range(2)] + _PEER_M = [torch.empty_like(m) for _ in range(2)] + _PEER_L = [torch.empty_like(l) for _ in range(2)] + _PEER_O = [torch.empty_like(o) for _ in range(2)] + + return _PEER_Q, _PEER_K, _PEER_V, _PEER_M, _PEER_L, _PEER_O + +def maybe_get_set_global_memory_buffer_bwd(dq, dk, dv, q, L, k, v, o, do): + global _DELTA_DQ, _DELTA_DK, _DELTA_DV, _DK_DELTA_FROM_PEER, _DV_DELTA_FROM_PEER,_PEER_Q_BWD, _PEER_L, _PEER_K_BWD, _PEER_V_BWD, _PEER_O_BWD, _PEER_DO + if _DELTA_DQ is None: + try: + if get_sequence_parallel_rank() == 0: + print("Initializing global memoery buffer for backward.") + except: + print("Initializing global memoery buffer for backward.") + _DELTA_DQ = [torch.empty_like(dq) for _ in range(2)] + _DELTA_DK = [torch.empty_like(dk) for _ in range(2)] + _DELTA_DV = [torch.empty_like(dv) for _ in range(2)] + _PEER_L = [torch.empty_like(L) for _ in range(2)] + + _DK_DELTA_FROM_PEER = torch.empty_like(dk) + _DV_DELTA_FROM_PEER = torch.empty_like(dv) + + # may already be initailized in the forward call. + # current forward and backward needs a transpose in q's format + _PEER_Q_BWD = [torch.empty_like(q) for _ in range(2)] + _PEER_K_BWD = [torch.empty_like(k) for _ in range(2)] + _PEER_V_BWD = [torch.empty_like(v) for _ in range(2)] + _PEER_O_BWD = [torch.empty_like(o) for _ in range(2)] + + _PEER_DO = [torch.empty_like(do) for _ in range(2)] + + return _DELTA_DQ, _DELTA_DK, _DELTA_DV, _DK_DELTA_FROM_PEER, _DV_DELTA_FROM_PEER, _PEER_Q_BWD, _PEER_L, _PEER_K_BWD, _PEER_V_BWD, _PEER_O_BWD, _PEER_DO + +def reset_global_memory_buffer(): + global _PEER_Q, _PEER_K, _PEER_V, _PEER_M, _PEER_L, _PEER_O, _DELTA_DQ, _PEER_L, _DELTA_DK, _DELTA_DV, _DK_DELTA_FROM_PEER, _DV_DELTA_FROM_PEER, _PEER_DO + _PEER_Q = None + _PEER_K = None + _PEER_V = None + _PEER_M = None + _PEER_L = None + _PEER_O = None + + _DELTA_DQ = None + _PEER_L = None + _DELTA_DK = None + _DELTA_DV = None + _DK_DELTA_FROM_PEER = None + _DV_DELTA_FROM_PEER = None + _PEER_DO = None + +# Pytorch defers the creation of nccl communicators to the first P2P call, +# We manually create them so the first isend does not hang without an irecv. +# reference: https://github.com/pytorch/pytorch/blob/main/torch/csrc/cuda/nccl.cpp#L138 +# Only support even number of GPUs. 
+def create_nccl_communicators(): + seq_rank = get_sequence_parallel_rank() + seq_group = get_sequence_parallel_group() + + empty_tensor = torch.empty(1,).cuda() + empty_tensor_2 = torch.empty(1,).cuda() + if torch.distributed.get_rank() % 2 == 0: + # sender + op1 = P2POp(op=isend, tensor=torch.empty(1,).cuda(), peer=seq_rank+1, group=seq_group) + op2 = P2POp(op=irecv, tensor=torch.empty(1,).cuda(), peer=seq_rank+1, group=seq_group) + #req = torch.distributed.isend(tensor=empty_tensor, dst=seq_rank + 1, group=seq_group) + dist.batch_isend_irecv([op1, op2]) + else: + # receiver + op1 = P2POp(op=irecv, tensor=torch.empty(1,).cuda(), peer=seq_rank-1, group=seq_group) + op2 = P2POp(op=isend, tensor=torch.empty(1,).cuda(), peer=seq_rank-1, group=seq_group) + #req = torch.distributed.isend(tensor=empty_tensor, dst=seq_rank + 1, group=seq_group) + handles = dist.batch_isend_irecv([op1, op2]) + #req = torch.distributed.irecv(tensor=empty_tensor, src=seq_rank - 1, group=seq_group) + dist.all_reduce(empty_tensor, group=seq_group) + +def get_sequence_parallel_group(): + """Get the sequence parallel group the caller rank belongs to.""" + #global _SEQUENCE_PARALLEL_GROUP + assert ( + _SEQUENCE_PARALLEL_GROUP is not None + ), 'sequence parallel group is not initialized' + return _SEQUENCE_PARALLEL_GROUP + +def get_sequence_parallel_rank(): + """Return my rank for the sequence parallel group.""" + global _SEQUENCE_PARALLEL_RANK + if _SEQUENCE_PARALLEL_RANK is not None: + return _SEQUENCE_PARALLEL_RANK + return torch.distributed.get_rank(group=get_sequence_parallel_group()) + +def get_sequence_parallel_size(): + """Return my rank for the sequence parallel group.""" + global _SEQUENCE_PARALLEL_SIZE + if _SEQUENCE_PARALLEL_SIZE is not None: + return _SEQUENCE_PARALLEL_SIZE + return torch.distributed.get_world_size(group=get_sequence_parallel_group()) + +def destroy_sequence_parallel(): + """Set the groups to none.""" + global _SEQUENCE_PARALLEL_GROUP + _SEQUENCE_PARALLEL_GROUP = None + +# whether this is the last time the kernel being called +def is_last_time(time_step): + # e.g. 
on a 8-GPU setup: + # R=0: 0 + # R=1: 1 + # R=2: 2 + # R=3: 3 + # R=4: 4, 5, 6, 7 + seq_rank = get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + if seq_rank <= seq_world_size // 2: # no one helps these ranks + rank_finish_time = seq_rank + else: + rank_finish_time = seq_world_size // 2 + return rank_finish_time == time_step + +# Whether the current time step is computing for local q +def is_compute_for_local_query(time_step): + # R=3,4,5,6,7: Yes + # R=0: 0 + # R=1: 0, 1 + # R=2: 0, 1, 2 + seq_rank = get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + if seq_rank >= min(seq_world_size // 2, time_step): + return True + return False + +# Whether the current time step is idle +def is_idle(time_step): + # 0, 1, 2, 3: 4 + # 4, 5, 6, 7: No + seq_rank = get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + + if seq_rank < (seq_world_size // 2) and time_step == seq_world_size // 2: + return True + return False + +# Whether the current time step needs to synchronize with a remote computed result +def is_sync_from_remote(time_step): + # R=0, 1, 2, 3, 4: No + # R=5: 4 + # R=6: 3, 4 + # R=7: 2, 3, 4 + seq_rank = get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + if seq_rank > max(seq_world_size // 2, seq_world_size - time_step): + return True + return False + +def maybe_send_recv_fwd_qkvo(q: torch.Tensor, peer_q: torch.Tensor, + k: torch.Tensor, peer_k: torch.Tensor, + v: torch.Tensor, peer_v: torch.Tensor, + o_stats: list,# peer_o_stats: list, + time_step: int, comm_mode, debug=False) -> torch.Tensor: + + seq_group = get_sequence_parallel_group() + seq_rank = get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + seq_offset = get_process_group_ranks(seq_group)[0] + + # Handles for operations that actually need to be wait before going to the next iteration. + # For instance, QKV sender never needs to wait -> it seems fusing these calls help scheduler; + all_handles = [] + # KV logic: different than older version, every rank to send/recv its own kv, + # to balance communication. In a balanced communication, every step each rank + # should send/recv 4 tensors in total (kv, or qo). For instance, rank 0 when + # time step > 0, should send its own kv and send/recv qo. In the older version, + # rank 0 does not send its kv, and rely on a later rank to pass it, where the + # later rank has to (1) receive kv, send rank 0's kv and send/recv qo. + # Q (load balancing) logic: semantically, this will be "%" world size, so + # the same send/recv rank as KV. Note: Only support even number of machines. + # O (load balancing) logic: rank 0 sends result to rank 7 at time 1. + # It get delayed for one time step, and thus has different maybe_send/recv_rank. + # Use (time_step + 1) to easily convert to synchornize version. 
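+    # Example for an 8-way group at time_step=1: rank 2 sends its kv to rank 4
+    # and receives kv from rank 0, while rank 6 wraps around (6 + 2 >= 8) and
+    # instead sends its q to rank 0, which then computes attention on rank 6's
+    # behalf; seq_offset translates these group-local indices to global ranks.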
+ maybe_send_rank = seq_rank + (time_step + 1) + maybe_recv_rank = seq_rank - (time_step + 1) + + if debug: + global _fwd_send_volume, _fwd_recv_volume, _bwd_send_volume, _bwd_recv_volume + _debug_send = _fwd_send_volume + _debug_recv = _fwd_recv_volume + + if maybe_send_rank >= seq_world_size: + #send q, no one needs to do remote computation in the last time step + if time_step < (seq_world_size // 2 - 1): + #print(f"t={time_step}: R={seq_rank} sends q to {maybe_send_rank % seq_world_size} (not wait)") + #q_send_handles.append(P2POp(op=isend, tensor=q, peer=maybe_send_rank % seq_world_size, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=q, peer=maybe_send_rank % seq_world_size + seq_offset, group=seq_group)) + if debug: + _fwd_send_volume += torch.numel(q) * q.element_size() + else: + # send kv + #print(f"t={time_step}: R={seq_rank} sends kv to {maybe_send_rank} (not wait)") + #kv_send_handles.append(P2POp(op=isend, tensor=k, peer=maybe_send_rank, group=seq_group)) + #kv_send_handles.append(P2POp(op=isend, tensor=v, peer=maybe_send_rank, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=k, peer=maybe_send_rank + seq_offset, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=v, peer=maybe_send_rank + seq_offset, group=seq_group)) + if debug: + _fwd_send_volume += torch.numel(k) * k.element_size() + _fwd_send_volume += torch.numel(v) * v.element_size() + + if maybe_recv_rank < 0: + # recv q, no one needs to do remote computation in the last time step + if time_step < (seq_world_size // 2 - 1): + # print(f"t={time_step}: R={seq_rank} receives q from {maybe_recv_rank % seq_world_size} (wait)") + #q_recv_handles.append(P2POp(op=irecv, tensor=peer_q, peer=maybe_recv_rank % seq_world_size, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=peer_q, peer=maybe_recv_rank % seq_world_size + seq_offset, group=seq_group)) + if debug: + _fwd_recv_volume += torch.numel(peer_q) * peer_q.element_size() + else: + # recv kv + #print(f"t={time_step}: R={seq_rank} receivs kv from {maybe_recv_rank} (wait)") + #kv_recv_handles.append(P2POp(op=irecv, tensor=peer_k, peer=maybe_recv_rank, group=seq_group)) + #kv_recv_handles.append(P2POp(op=irecv, tensor=peer_v, peer=maybe_recv_rank, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=peer_k, peer=maybe_recv_rank + seq_offset, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=peer_v, peer=maybe_recv_rank + seq_offset, group=seq_group)) + if debug: + _fwd_recv_volume += torch.numel(peer_k) * peer_k.element_size() + _fwd_recv_volume += torch.numel(peer_v) * peer_v.element_size() + + maybe_send_rank_o = seq_rank - (time_step - 1) + maybe_recv_rank_o = seq_rank + (time_step - 1) + if maybe_send_rank_o < 0 and time_step > 1: + for t in o_stats: + # print(f"t={time_step}: R={seq_rank} sends o to {maybe_send_rank_o % seq_world_size} (wait)") + #o_send_handles.append(P2POp(op=isend, tensor=t, peer=maybe_send_rank_o % seq_world_size, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=t, peer=maybe_send_rank_o % seq_world_size + seq_offset, group=seq_group)) + if debug: + _fwd_send_volume += torch.numel(t) * t.element_size() + if maybe_recv_rank_o >= seq_world_size and time_step > 1 : + for t in o_stats: + # print(f"t={time_step}: R={seq_rank} receives o from {maybe_recv_rank_o % seq_world_size} (wait)") + #o_recv_handles.append(P2POp(op=irecv, tensor=t, peer=maybe_recv_rank_o % seq_world_size, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=t, peer=maybe_recv_rank_o % 
seq_world_size + seq_offset, group=seq_group)) + if debug: + _fwd_recv_volume += torch.numel(t) * t.element_size() + + #reqs = [] + + if debug: + if seq_rank in [0, 8]: + print(f"R={seq_rank} time_step={time_step} increases: send {(_fwd_send_volume - _debug_send) * 1e-9} GB recv {(_fwd_recv_volume - _debug_recv) * 1e-9} GB") + #return reqs + all_reqs = launch_async_handles(all_handles, comm_mode) + return [all_reqs] + +# delta: may be you are using it for your local compute or as a distributed buffer to send to others +# .. Sorry for the bad naming.. +def maybe_send_recv_bwd_qkvo(dq_delta: torch.Tensor, dk_delta: torch.Tensor, + dv_delta: torch.Tensor, dk_delta_from_peer: torch.Tensor, + dv_delta_from_peer: torch.Tensor, q: torch.Tensor, + peer_q: torch.Tensor, L: torch.Tensor, + peer_L: torch.Tensor, k: torch.Tensor, + peer_k: torch.Tensor, v: torch.Tensor, + peer_v: torch.Tensor, o: torch.Tensor, + peer_o: torch.Tensor, do: torch.Tensor, + peer_do: torch.Tensor, time_step: int, comm_mode, debug=False): + + seq_group = get_sequence_parallel_group() + seq_rank = get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + seq_offset = get_process_group_ranks(seq_group)[0] + + all_handles = [] + maybe_send_rank = seq_rank + (time_step + 1) + maybe_recv_rank = seq_rank - (time_step + 1) + + if debug: + global _fwd_send_volume, _fwd_recv_volume, _bwd_send_volume, _bwd_recv_volume + + if maybe_send_rank >= seq_world_size: + #send q, no one needs to do remote computation in the last time step + if time_step < (seq_world_size // 2 - 1): + all_handles.append(P2POp(op=isend, tensor=q, peer=maybe_send_rank % seq_world_size + seq_offset, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=L, peer=maybe_send_rank % seq_world_size + seq_offset, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=o, peer=maybe_send_rank % seq_world_size + seq_offset, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=do, peer=maybe_send_rank % seq_world_size + seq_offset, group=seq_group)) + if debug: + _bwd_send_volume += torch.numel(q) * q.element_size() + _bwd_send_volume += torch.numel(L) * L.element_size() + _bwd_send_volume += torch.numel(o) * o.element_size() + _bwd_send_volume += torch.numel(do) * do.element_size() + else: + # send kv + all_handles.append(P2POp(op=isend, tensor=k, peer=maybe_send_rank + seq_offset, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=v, peer=maybe_send_rank + seq_offset, group=seq_group)) + if debug: + _bwd_send_volume += torch.numel(k) * k.element_size() + _bwd_send_volume += torch.numel(v) * v.element_size() + + if maybe_recv_rank < 0: + # recv q, no one needs to do remote computation in the last time step + if time_step < (seq_world_size // 2 - 1): + all_handles.append(P2POp(op=irecv, tensor=peer_q, peer=maybe_recv_rank % seq_world_size + seq_offset, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=peer_L, peer=maybe_recv_rank % seq_world_size + seq_offset, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=peer_o, peer=maybe_recv_rank % seq_world_size + seq_offset, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=peer_do, peer=maybe_recv_rank % seq_world_size + seq_offset, group=seq_group)) + if debug: + _bwd_recv_volume += torch.numel(peer_q) * peer_q.element_size() + _bwd_recv_volume += torch.numel(peer_L) * peer_L.element_size() + _bwd_recv_volume += torch.numel(peer_o) * peer_o.element_size() + _bwd_recv_volume += torch.numel(peer_do) * peer_do.element_size() + else: + # recv 
kv + all_handles.append(P2POp(op=irecv, tensor=peer_k, peer=maybe_recv_rank + seq_offset, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=peer_v, peer=maybe_recv_rank + seq_offset, group=seq_group)) + if debug: + _bwd_recv_volume += torch.numel(peer_k) * peer_k.element_size() + _bwd_recv_volume += torch.numel(peer_v) * peer_v.element_size() + + # Whether I should update dq, dk and dv after waiting these requests + is_update_dq = False + is_update_dkv = False + + maybe_send_rank_dqkv = seq_rank - (time_step - 1) + maybe_recv_rank_dqkv = seq_rank + (time_step - 1) + + if time_step > 1: + if maybe_send_rank_dqkv < 0: + #print(f"BWD t={time_step}: R={seq_rank} sends dq delta to {maybe_send_rank_dqkv % seq_world_size}") + all_handles.append(P2POp(op=isend, tensor=dq_delta, peer=maybe_send_rank_dqkv % seq_world_size + seq_offset, group=seq_group)) + if debug: + _bwd_send_volume += torch.numel(dq_delta) * dq_delta.element_size() + else: + #print(f"BWD t={time_step}: R={seq_rank} sends dkv delta to {maybe_send_rank_dqkv}") + all_handles.append(P2POp(op=isend, tensor=dk_delta, peer=maybe_send_rank_dqkv + seq_offset, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=dv_delta, peer=maybe_send_rank_dqkv + seq_offset, group=seq_group)) + if debug: + _bwd_send_volume += torch.numel(dk_delta) * dk_delta.element_size() + _bwd_send_volume += torch.numel(dv_delta) * dv_delta.element_size() + + if maybe_recv_rank_dqkv >= seq_world_size: + #print(f"BWD t={time_step}: R={seq_rank} receives dq delta to {maybe_recv_rank_dqkv % seq_world_size}") + all_handles.append(P2POp(op=irecv, tensor=dq_delta, peer=maybe_recv_rank_dqkv % seq_world_size + seq_offset, group=seq_group)) + is_update_dq = True + if debug: + _bwd_recv_volume += torch.numel(dq_delta) * dq_delta.element_size() + else: + #print(f"BWD t={time_step}: R={seq_rank} receives dk dv delta from {maybe_recv_rank_dqkv}") + all_handles.append(P2POp(op=irecv, tensor=dk_delta_from_peer, peer=maybe_recv_rank_dqkv + seq_offset, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=dv_delta_from_peer, peer=maybe_recv_rank_dqkv + seq_offset, group=seq_group)) + is_update_dkv = True + if debug: + _bwd_recv_volume += torch.numel(dk_delta_from_peer) * dk_delta_from_peer.element_size() + _bwd_recv_volume += torch.numel(dv_delta_from_peer) * dv_delta_from_peer.element_size() + + # return [], is_update_dq, is_update_dkv + all_reqs = launch_async_handles(all_handles, comm_mode) + return [all_reqs], is_update_dq, is_update_dkv + +def maybe_send_recv_bwd_last_dkv(dk_delta: torch.Tensor, dv_delta: torch.Tensor, time_step, comm_mode, debug=False): + is_update_last_dkv = False + + seq_group = get_sequence_parallel_group() + seq_rank = get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + seq_offset = get_process_group_ranks(seq_group)[0] + + if seq_world_size == 1: return [], is_update_last_dkv + + all_handles = [] + + if debug: + global _fwd_send_volume, _fwd_recv_volume, _bwd_send_volume, _bwd_recv_volume + + if time_step == seq_world_size // 2: + maybe_send_rank = seq_rank - time_step + maybe_recv_rank = seq_rank + time_step + + assert (maybe_send_rank >= 0) ^ (maybe_recv_rank < seq_world_size), "R={seq_rank} should be either sending or receiving dkv in the last time step." 
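+        # With t = seq_world_size // 2, maybe_send_rank = rank - t and
+        # maybe_recv_rank = rank + t, so exactly one of them is in range:
+        # in an 8-way group, ranks 4..7 send their final dk/dv deltas to
+        # ranks 0..3, which only receive. The assert above checks that pairing.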
+ + if maybe_send_rank >= 0: + # print(f"BWD t={time_step}: R={seq_rank} last send dkv to {maybe_send_rank}") + all_handles.append(P2POp(op=isend, tensor=dk_delta, peer=maybe_send_rank + seq_offset, group=seq_group)) + all_handles.append(P2POp(op=isend, tensor=dv_delta, peer=maybe_send_rank + seq_offset, group=seq_group)) + if debug: + _bwd_send_volume += torch.numel(dk_delta) * dk_delta.element_size() + _bwd_send_volume += torch.numel(dv_delta) * dv_delta.element_size() + if maybe_recv_rank < seq_world_size: + # print(f"BWD t={time_step}: R={seq_rank} last receive dkv from {maybe_recv_rank}") + all_handles.append(P2POp(op=irecv, tensor=dk_delta, peer=maybe_recv_rank + seq_offset, group=seq_group)) + all_handles.append(P2POp(op=irecv, tensor=dv_delta, peer=maybe_recv_rank + seq_offset, group=seq_group)) + if debug: + _bwd_recv_volume += torch.numel(dk_delta) * dk_delta.element_size() + _bwd_recv_volume += torch.numel(dv_delta) * dv_delta.element_size() + is_update_last_dkv = True + + # return [], is_update_last_dkv + all_reqs = launch_async_handles(all_handles, comm_mode) + + return [all_reqs], is_update_last_dkv + +def print_and_reset_comm_stats(): + seq_rank = get_sequence_parallel_rank() + + global _fwd_send_volume, _fwd_recv_volume, _bwd_send_volume, _bwd_recv_volume + _fwd_send_volume *= 1e-9 + _fwd_recv_volume *= 1e-9 + _bwd_send_volume *= 1e-9 + _bwd_recv_volume *= 1e-9 + + print(f"R={seq_rank} fwd send: {_fwd_send_volume} fwd recv: {_fwd_recv_volume}; bwd send: {_bwd_send_volume}, bwd recv: {_bwd_recv_volume} GB.") + _fwd_send_volume = 0 + _fwd_recv_volume = 0 + _bwd_send_volume = 0 + _bwd_recv_volume = 0 + +def launch_async_handles(handles, comm_mode): + global _args + if comm_mode == "nocomm": + #print("skipping communication for ablation") + return [] + if len(handles) > 0: + return dist.batch_isend_irecv(handles) + return [] + +def wait_async_handles(reqs): + if len(reqs) > 0: + for req in reqs: + for r in req: + r.wait() \ No newline at end of file diff --git a/src/llamafactory/easy_context/dist_flash_attn/lightseq_async_attn.py b/src/llamafactory/easy_context/dist_flash_attn/lightseq_async_attn.py new file mode 100644 index 0000000000..d776495bc3 --- /dev/null +++ b/src/llamafactory/easy_context/dist_flash_attn/lightseq_async_attn.py @@ -0,0 +1,743 @@ +import os +import math + +from einops import rearrange +import argparse + +import pytest +import torch +import torch.distributed as dist +from torch.distributed import ReduceOp +#from torch.profiler import profile, record_function, ProfilerActivity +import functools +import triton +import triton.language as tl +import time +import numpy as np +from tqdm import tqdm + +try: + from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward +except: + pass + +from .async_communication import (is_last_time, is_compute_for_local_query, is_sync_from_remote, is_idle, print_and_reset_comm_stats, + launch_async_handles, wait_async_handles, maybe_send_recv_fwd_qkvo, maybe_send_recv_bwd_qkvo, maybe_send_recv_bwd_last_dkv, reset_global_memory_buffer, + maybe_get_set_global_memory_buffer, maybe_get_set_global_memory_buffer_bwd, initialize_distributed, get_sequence_parallel_size, get_sequence_parallel_rank) + +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) + +@triton.jit +def _rescale_kernel( + peer_m, + m, + peer_l, + l, + peer_o, + o, + L, + stride_oz, stride_oh, stride_om, stride_on, + Z, H, N_CTX, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + LAST_STEP: tl.constexpr, +): 
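+    # Merge the partial result a peer computed for this rank's queries (over the
+    # peer's kv shard) into the local accumulator, using the running max (m) and
+    # normalizer (l) of the online softmax:
+    #   m_sync = max(m, peer_m)
+    #   acc    = acc * exp2(m - m_sync) + peer_acc * exp2(peer_m - m_sync)
+    #   l      = l   * exp2(m - m_sync) + peer_l   * exp2(peer_m - m_sync)
+    # Only when LAST_STEP is set is acc normalized by l and the logsumexp L
+    # written out (exp2 is used because qk was pre-scaled by log2(e)).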
+ start_m = tl.program_id(0) + off_hz = tl.program_id(1) + o_offset = off_hz * stride_oh + peer_o_block_ptr = tl.make_block_ptr( + base=peer_o + o_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + o_block_ptr = tl.make_block_ptr( + base=o + o_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + + peer_m_ptrs = peer_m + off_hz * N_CTX + offs_m + m_ptrs = m + off_hz * N_CTX + offs_m + peer_l_ptrs = peer_l + off_hz * N_CTX + offs_m + l_ptrs = l + off_hz * N_CTX + offs_m + + peer_m_i = tl.load(peer_m_ptrs) + peer_m_i = peer_m_i.to(tl.float32) + m_i = tl.load(m_ptrs) + m_i = m_i.to(tl.float32) + peer_l_i = tl.load(peer_l_ptrs) + peer_l_i = peer_l_i.to(tl.float32) + l_i = tl.load(l_ptrs) + l_i = l_i.to(tl.float32) + + peer_acc = tl.load(peer_o_block_ptr) + peer_acc = peer_acc.to(tl.float32) + acc = tl.load(o_block_ptr) + acc = acc.to(tl.float32) + lo = 0 + hi = N_CTX + m_i_sync = tl.maximum(m_i, peer_m_i) + alpha = tl.math.exp2(m_i - m_i_sync) + peer_alpha = tl.math.exp2(peer_m_i - m_i_sync) + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + peer_acc_scale = peer_l_i * 0 + peer_alpha # workaround some compiler bug + + acc *= acc_scale[:, None] + peer_acc *= peer_acc_scale[:, None] + acc += peer_acc + l_i = l_i * acc_scale + peer_l_i * peer_acc_scale + # write back O, l, m + tl.store(m_ptrs, m_i_sync) + tl.store(l_ptrs, l_i) + if LAST_STEP: + acc = acc / l_i[:, None] + L_ptrs = L + off_hz * N_CTX + offs_m + tl.store(L_ptrs, m_i_sync / 1.44269504 + tl.math.log(l_i)) + tl.store(o_block_ptr, acc.to(tl.bfloat16)) + +@triton.jit +def _fwd_kernel( + Q, K, V, sm_scale, + m, + l, + O, + L, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + stride_oz, stride_oh, stride_om, stride_on, + Z, H, N_CTX, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_CAUSAL: tl.constexpr, + LAST_STEP: tl.constexpr +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + qvk_offset = off_hz * stride_qh + Q_block_ptr = tl.make_block_ptr( + base=Q + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + K_block_ptr = tl.make_block_ptr( + base=K + qvk_offset, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0) + ) + O_block_ptr = tl.make_block_ptr( + base=O + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l -> load from provided pointer + m_ptrs = m + off_hz * N_CTX + offs_m + l_ptrs = l + off_hz * N_CTX + offs_m + m_i = tl.load(m_ptrs) + m_i = m_i.to(tl.float32) + l_i = 
tl.load(l_ptrs) + l_i = l_i.to(tl.float32) + acc = tl.load(O_block_ptr) + acc = acc.to(tl.float32) + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout + q = tl.load(Q_block_ptr) + q = (q * qk_scale).to(tl.bfloat16) + # loop over k, v and update accumulator + lo = 0 + hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX + for start_n in range(lo, hi, BLOCK_N): + # -- load k, v -- + k = tl.load(K_block_ptr) + v = tl.load(V_block_ptr) + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + if IS_CAUSAL: + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + qk += tl.dot(q, k) + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc *= acc_scale[:, None] + acc += tl.dot(p.to(tl.bfloat16), v) + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + # write back original l and m + tl.store(m_ptrs, m_i) + tl.store(l_ptrs, l_i) + # write back O, L + if LAST_STEP: + acc = acc / l_i[:, None] + L_ptrs = L + off_hz * N_CTX + offs_m + tl.store(L_ptrs, m_i / 1.44269504 + tl.math.log(l_i)) + tl.store(O_block_ptr, acc.to(tl.bfloat16)) + +# for gqa/mqa to expand kv heads +def maybe_repeat_kv_fwd(nqh, kv): + bs, nkvh, slen, hdim = kv.shape + n_rep = nqh // nkvh + if n_rep == 1: + return kv + kv_expand = kv[:, :, None, :, :].expand(bs, nkvh, n_rep, slen, hdim) + return kv_expand.reshape(bs, nkvh * n_rep, slen, hdim) + +def maybe_repeat_kv_bwd(nqh, kv): + bs, slen, nkvh, hdim = kv.shape + n_rep = nqh // nkvh + if n_rep == 1: + return kv + kv_expand = kv[:, :, :, None, :].expand(bs, slen, nkvh, n_rep, hdim) + return kv_expand.reshape(bs, slen, nkvh * n_rep, hdim) + +# kv grad has shape bs, slen, nqh, hdim +def maybe_reduce_dkv(nkvh, dkv): + bs, slen, nqh, hdim = dkv.shape + n_rep = nqh // nkvh + if n_rep == 1: + return dkv + dkv_reshape = dkv.view(bs, slen, nkvh, n_rep, hdim) + return torch.sum(dkv_reshape, dim=3) + + +def _lightseq_forward(q, k, v, causal, sm_scale, comm_mode): + # maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x + # q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + # Why do I have to change it from 128 64 to 32 32? 
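+    # Triton tile sizes: each program instance handles a BLOCK_M-row slice of
+    # queries for one (batch, head) pair and iterates over keys/values in
+    # BLOCK_N-column chunks; the launch grid below is
+    # (cdiv(seq_len, BLOCK_M), bsz * nh, 1).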
+ BLOCK_M = 32 + BLOCK_N = 32 + + bsz, nh, seq_len, hdim = q.shape + + m = torch.full((bsz * nh, seq_len), fill_value=-float("inf"), device=q.device, dtype=torch.float32) + l = torch.zeros_like(m) + L = torch.zeros_like(m) + o = torch.zeros_like(q) + + grid = (triton.cdiv(seq_len, BLOCK_M), bsz * nh, 1) + num_warps = 4 if Lk <= 64 else 8 + + seq_rank = get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + + # Initialize all buffers + peer_q, peer_k, peer_v, peer_m, peer_l, peer_o = maybe_get_set_global_memory_buffer(q, k, v, m, l, o) + + fwd_launch_helper = lambda q, k, v, m, l, o, L, IS_CAUSAL, LAST_STEP: _fwd_kernel[grid]( + q, k, v, sm_scale, + m, + l, + o, + L, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + q.shape[0], q.shape[1], q.shape[2], + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, + IS_CAUSAL=IS_CAUSAL, + LAST_STEP=LAST_STEP, + num_warps=num_warps, + num_stages=4) + + for time_step in range(seq_world_size // 2 + 1): + # This is important for cuda scheduler to execute nccl calls first. + torch.cuda.synchronize() + # Communication uses buffer_idx_1, and compute uses buffer_idx_2, which effectively are contents from the last time step. + buffer_idx_1 = time_step % 2 + buffer_idx_2 = (time_step - 1) % 2 + + reqs = maybe_send_recv_fwd_qkvo(q, peer_q[buffer_idx_1], k, peer_k[buffer_idx_1], v, peer_v[buffer_idx_1], + [peer_o[buffer_idx_1], peer_m[buffer_idx_1], peer_l[buffer_idx_1]], time_step, comm_mode) + if comm_mode == "sync": + # if seq_rank == 0: + # print("Immediate wait for abalation") + wait_async_handles(reqs) + if is_compute_for_local_query(time_step): + # print(f"t={time_step}: (Comp) R={seq_rank} local compute") + if time_step == 0: + fwd_launch_helper(q, maybe_repeat_kv_fwd(q.shape[1], k), maybe_repeat_kv_fwd(q.shape[1], v), m, l, o, L, True, is_last_time(time_step)) + else: + # if needs to sync from others, do not normalize here + fwd_launch_helper(q, maybe_repeat_kv_fwd(q.shape[1], peer_k[buffer_idx_2]), maybe_repeat_kv_fwd(q.shape[1], peer_v[buffer_idx_2]), m, l, o, L, False, not is_sync_from_remote(time_step) and is_last_time(time_step)) + elif is_idle(time_step): + # print(f"t={time_step}: (Comp) R={seq_rank} idle") + pass + else: + # print(f"t={time_step}: (Comp) R={seq_rank} helps other") + peer_m[buffer_idx_2] = torch.full_like(m, fill_value=-float("inf")) + peer_l[buffer_idx_2] = torch.zeros_like(l) + peer_o[buffer_idx_2] = torch.zeros_like(o) + + #print(f"rank 3 q is: {peer_q[buffer_idx_2]}") + fwd_launch_helper(peer_q[buffer_idx_2], maybe_repeat_kv_fwd(q.shape[1], k), maybe_repeat_kv_fwd(q.shape[1], v), peer_m[buffer_idx_2], peer_l[buffer_idx_2], peer_o[buffer_idx_2], None, False, False) + + if comm_mode == "lightseq": + # Make sure tensors for next steps are ready + wait_async_handles(reqs) + # sync between statistics get from other ranks and the local ones + if is_sync_from_remote(time_step): + _rescale_kernel[grid]( + peer_m[buffer_idx_1], + m, + peer_l[buffer_idx_1], + l, + peer_o[buffer_idx_1], + o, + L, + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + o.shape[0], o.shape[1], o.shape[2], + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, + LAST_STEP=is_last_time(time_step), + num_warps=num_warps, + num_stages=4) + return q, k, v, o, L + +def _lightseq_backward(do, q, k, v, o, L, sm_scale, comm_mode, backward_engine): + BLOCK = 128 + q, k, v, o, do 
= [rearrange(_x, 'b h s d -> b s h d').contiguous() for _x in [q, k, v, o, do]] + L = rearrange(L, '(b h) s -> b h s', b=q.shape[0]) + + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + + # maybe gqa + nqh = q.shape[2] + nkvh = k.shape[2] + is_gqa = (nqh > nkvh) + + seq_rank = get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + + # Initialize all backward buffers + dq_delta, dk_delta, dv_delta, dk_delta_from_peer, dv_delta_from_peer, \ + peer_q, peer_L, peer_k, peer_v, peer_o, peer_do = maybe_get_set_global_memory_buffer_bwd(dq, dk, dv, q, L, k, v, o, do) + + for time_step in range(0, get_sequence_parallel_size() // 2 + 1): + torch.cuda.synchronize() + buffer_idx_1 = time_step % 2 + buffer_idx_2 = (time_step - 1) % 2 + + reqs, is_update_dq, is_update_dkv = maybe_send_recv_bwd_qkvo(dq_delta[buffer_idx_1], dk_delta[buffer_idx_1], dv_delta[buffer_idx_1], dk_delta_from_peer, dv_delta_from_peer, q, peer_q[buffer_idx_1], L, peer_L[buffer_idx_1], k, peer_k[buffer_idx_1], v, peer_v[buffer_idx_1], o, peer_o[buffer_idx_1], do, peer_do[buffer_idx_1], time_step, comm_mode) + if comm_mode == "sync": + # if seq_rank == 0: + # print("(bwd) Immediate wait for abalation") + wait_async_handles(reqs) + + if is_compute_for_local_query(time_step): + if time_step == 0: + if backward_engine == "flash": + _flash_attn_backward(do, q, k, v, o, L, dq, dk, dv, 0.0, sm_scale, True, (-1,-1), None, False) + else: + inp = Inputs(query=q, key=maybe_repeat_kv_bwd(q.shape[2], k), value=maybe_repeat_kv_bwd(q.shape[2], v), attn_bias=xformers.ops.LowerTriangularMask(), p=0, scale=sm_scale) + op_ctx = Context(lse=L, out=o, rng_state=None) + # Let xformers dispatch the correct backend + grads = _memory_efficient_attention_backward(ctx=op_ctx, inp=inp, grad=do, op=None) + dq = grads.dq + dk, dv = maybe_reduce_dkv(nkvh, grads.dk), maybe_reduce_dkv(nkvh, grads.dv) + else: + if backward_engine == "flash": + _flash_attn_backward(do, q, peer_k[buffer_idx_2], peer_v[buffer_idx_2], o, L, dq_delta[buffer_idx_2], dk_delta[buffer_idx_2], dv_delta[buffer_idx_2], 0.0, sm_scale, False, (-1,-1), None, False) + else: + inp = Inputs(query=q, key=maybe_repeat_kv_bwd(q.shape[2], peer_k[buffer_idx_2]), value=maybe_repeat_kv_bwd(q.shape[2], peer_v[buffer_idx_2]), attn_bias=None, p=0, scale=sm_scale) + op_ctx = Context(lse=L, out=o, rng_state=None) + grads = _memory_efficient_attention_backward(ctx=op_ctx, inp=inp, grad=do, op=None) + dq_delta[buffer_idx_2] = grads.dq + dk_delta[buffer_idx_2], dv_delta[buffer_idx_2] = maybe_reduce_dkv(nkvh, grads.dk), maybe_reduce_dkv(nkvh, grads.dv) + dq += dq_delta[buffer_idx_2] + elif is_idle(time_step): + pass + else: + if backward_engine == "flash": + _flash_attn_backward(peer_do[buffer_idx_2], peer_q[buffer_idx_2], k, v, peer_o[buffer_idx_2], peer_L[buffer_idx_2], dq_delta[buffer_idx_2], dk_delta[buffer_idx_2], dv_delta[buffer_idx_2], 0.0, sm_scale, False, (-1,-1), None, False) + else: + inp = Inputs(query=peer_q[buffer_idx_2], key=maybe_repeat_kv_bwd(q.shape[2], k), value=maybe_repeat_kv_bwd(q.shape[2], v), attn_bias=None, p=0, scale=sm_scale) + op_ctx = Context(lse=peer_L[buffer_idx_2], out=peer_o[buffer_idx_2], rng_state=None) + grads = _memory_efficient_attention_backward(ctx=op_ctx, inp=inp, grad=peer_do[buffer_idx_2], op=None) + dq_delta[buffer_idx_2] = grads.dq + dk_delta[buffer_idx_2], dv_delta[buffer_idx_2] = maybe_reduce_dkv(nkvh, grads.dk), maybe_reduce_dkv(nkvh, grads.dv) + dk += dk_delta[buffer_idx_2] + dv += dv_delta[buffer_idx_2] + + if 
comm_mode == "lightseq": + # Make sure tensors for next steps are ready + wait_async_handles(reqs) + + # The last time step needs to send dk and dv immediately, move it up here to maximize overlap with the following three addition. + reqs, is_update_last_dkv = maybe_send_recv_bwd_last_dkv(dk_delta[buffer_idx_2], dv_delta[buffer_idx_2], time_step, comm_mode) + + if comm_mode == "sync": + # if seq_rank == 0: + # print("(bwd) dkv Immediate wait for abalation") + wait_async_handles(reqs) + # apply dq_delta, dk_delta and dv_delta from remote + if is_update_dq: + dq += dq_delta[buffer_idx_1] + if is_update_dkv: + dk += dk_delta_from_peer + dv += dv_delta_from_peer + + if comm_mode == "lightseq": + wait_async_handles(reqs) + # apply dk_delta and dv_delta to sender + if is_update_last_dkv: + dk += dk_delta[buffer_idx_2] + dv += dv_delta[buffer_idx_2] + + dq, dk, dv = [rearrange(_x, 'b h s d -> b s h d') for _x in [dq, dk, dv]] + return dq, dk, dv + +class _attention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, causal, sm_scale): + try: + global args + comm_mode = args.comm_mode + backward_engine = args.backward_engine + except: + comm_mode = 'lightseq' + backward_engine = 'flash' + + q, k, v, o, L = _lightseq_forward(q, k, v, causal, sm_scale, comm_mode) + + ctx.save_for_backward(q, k, v, o, L) + ctx.sm_scale = sm_scale + ctx.comm_mode = comm_mode + ctx.backward_engine = backward_engine + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, L = ctx.saved_tensors + sm_scale = ctx.sm_scale + + dq, dk, dv = _lightseq_backward(do, q, k, v, o, L, sm_scale, ctx.comm_mode, ctx.backward_engine) + return dq, dk, dv, None, None + +attention = _attention.apply + + +#@pytest.mark.parametrize('causal', [False, True]) +#@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(6, 9, 1024, 64)]) +def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.bfloat16): + torch.manual_seed(20) + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + + rank = dist.get_rank() + world_size = dist.get_world_size() + seq_per_rank = N_CTX // world_size + + sm_scale = 0.5 + dout = torch.randn_like(q) + # reference implementation + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + assert causal + if causal: + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).half() + ref_out = torch.matmul(p, v) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + + # triton implementation + + a, b, c, d = q.size() + real_q = q[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True) + real_k = k[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True) + real_v = v[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True) + real_do = dout[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True) + + tri_out = attention(real_q, real_k, real_v, causal, sm_scale).half() + + # compare + assert 
torch.allclose(ref_out[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_out, atol=1e-2, rtol=0), f" rank {rank} fails forward" + print(f" *** rank {rank} passes forward") + tri_out.backward(real_do) + tri_dv, real_v.grad = real_v.grad.clone(), None + tri_dk, real_k.grad = real_k.grad.clone(), None + tri_dq, real_q.grad = real_q.grad.clone(), None + assert torch.allclose(ref_dq[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dq, atol=1e-2, rtol=0), f" rank {rank} fails backward dq" + assert torch.allclose(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dk, atol=1e-2, rtol=0), f"rank {rank} fails backward dk" #f" {ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dk} {torch.max(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dk)} rank {rank} fails backward dk" + assert torch.allclose(ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dv, atol=1e-2, rtol=0), f"rank {rank} fails backward dv {ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dv} {torch.max(ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dv)} rank {rank} fails backward dv" + print(f"rank {rank} passes backward") + + +def test_gqa(Z, H, KVH, N_CTX, D_HEAD, causal, dtype=torch.bfloat16): + torch.manual_seed(177) + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + k = torch.empty((Z, KVH, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + v = torch.empty((Z, KVH, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + + rank = dist.get_rank() + world_size = dist.get_world_size() + seq_per_rank = N_CTX // world_size + + sm_scale = 0.5 + dout = torch.randn_like(q) + # torch reference implementation + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + ref_k = maybe_repeat_kv_fwd(q.shape[1], k).clone().detach().requires_grad_(True) + ref_v = maybe_repeat_kv_fwd(q.shape[1], v).clone().detach().requires_grad_(True) + p = torch.matmul(q, ref_k.transpose(2,3)) * sm_scale + assert causal + if causal: + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).half() + ref_out = torch.matmul(p, ref_v) + ref_out.backward(dout) + ref_dv, v.grad = ref_v.grad.clone(), None + ref_dv = (maybe_reduce_dkv(KVH, ref_dv.transpose(1,2))).transpose(1,2) + ref_dk, k.grad = ref_k.grad.clone(), None + ref_dk = (maybe_reduce_dkv(KVH, ref_dk.transpose(1,2))).transpose(1,2) + ref_dq, q.grad = q.grad.clone(), None + + # flash reference + from flash_attn import flash_attn_qkvpacked_func, flash_attn_func + flash_q = q.transpose(1,2).clone().detach().requires_grad_(True) + flash_k = k.transpose(1,2).clone().detach().requires_grad_(True) + flash_v = v.transpose(1,2).clone().detach().requires_grad_(True) + flash_ref_out = flash_attn_func(flash_q, flash_k, flash_v, 0, sm_scale, True) + flash_ref_out.backward(dout.transpose(1,2)) + flash_ref_out = flash_ref_out.transpose(1,2) + flash_ref_dv, v.grad = flash_v.grad.clone(), None + flash_ref_dv = flash_ref_dv.transpose(1,2) + flash_ref_dk, k.grad = flash_k.grad.clone(), None + flash_ref_dk = flash_ref_dk.transpose(1,2) + flash_ref_dq, q.grad = flash_q.grad.clone(), None + flash_ref_dq = flash_ref_dq.transpose(1,2) + + # triton implementation + + a, b, c, d = q.size() + real_q = q[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True) + real_k = k[:,:, 
rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, KVH, -1, d).contiguous().clone().detach().requires_grad_(True) + real_v = v[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, KVH, -1, d).contiguous().clone().detach().requires_grad_(True) + real_do = dout[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True) + + tri_out = attention(real_q, real_k, real_v, causal, sm_scale).half() + + # compare + assert torch.allclose(flash_ref_out[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_out, atol=1e-2, rtol=0), f" rank {rank} fails forward against flash" + print(f" *** rank {rank} passes forward") + tri_out.backward(real_do) + tri_dv, real_v.grad = real_v.grad.clone(), None + tri_dk, real_k.grad = real_k.grad.clone(), None + tri_dq, real_q.grad = real_q.grad.clone(), None + assert torch.allclose(flash_ref_dq[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dq, atol=1e-2, rtol=0), f" rank {rank} fails backward dq against flash" + #print(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].shape, ref_dk.shape, tri_dk.shape) + assert torch.allclose(flash_ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dk, atol=1e-2, rtol=0), f"rank {rank} fails backward dk against flash {flash_ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dk} {torch.max(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dk)} rank {rank} fails backward dk" + assert torch.allclose(flash_ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dv, atol=1e-2, rtol=0), f"rank {rank} fails backward dv against flash {flash_ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dv} {torch.max(flash_ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dv)} rank {rank} fails backward dv" + print(f"rank {rank} passes backward against flash") + + assert torch.allclose(ref_out[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_out, atol=1e-2, rtol=0), f" rank {rank} fails forward" + print(f" *** rank {rank} passes forward") + assert torch.allclose(ref_dq[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dq, atol=1e-2, rtol=0), f" rank {rank} fails backward dq" + #print(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].shape, ref_dk.shape, tri_dk.shape) + assert torch.allclose(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dk, atol=1e-2, rtol=0), f"rank {rank} fails backward dk {ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dk} {torch.max(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dk)} rank {rank} fails backward dk" + assert torch.allclose(ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dv, atol=1e-2, rtol=0), f"rank {rank} fails backward dv {ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dv} {torch.max(ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dv)} rank {rank} fails backward dv" + print(f"rank {rank} passes backward") + +#BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 +try: + from flash_attn.flash_attn_interface import \ + flash_attn_qkvpacked_func as flash_attn_func + FLASH_VER = 2 +except BaseException: + try: + from flash_attn.flash_attn_interface import flash_attn_func + FLASH_VER = 1 + except BaseException: + FLASH_VER = None +HAS_FLASH = FLASH_VER is not None +HAS_FLASH = None +ONLY_FLASH = False + +#BATCH, N_HEADS, N_CTX, 
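The GQA path above relies on `maybe_repeat_kv_fwd` to tile K/V heads up to the number of query heads and on `maybe_reduce_dkv` to sum the resulting gradients back per K/V head. A small self-contained sketch of that round trip (function names and shapes here are illustrative, chosen only to mirror the semantics):

import torch

def repeat_kv(nqh, kv):                      # kv: (B, nkvh, S, D) -> (B, nqh, S, D)
    b, nkvh, s, d = kv.shape
    return kv[:, :, None].expand(b, nkvh, nqh // nkvh, s, d).reshape(b, nqh, s, d)

def reduce_dkv(nkvh, dkv):                   # dkv: (B, S, nqh, D) -> (B, S, nkvh, D)
    b, s, nqh, d = dkv.shape
    return dkv.reshape(b, s, nkvh, nqh // nkvh, d).sum(dim=3)

kv = torch.randn(1, 2, 4, 8)                 # 2 kv heads
expanded = repeat_kv(8, kv)                  # tiled to 8 query heads
assert expanded.shape == (1, 8, 4, 8)
grad = torch.randn(1, 4, 8, 8)               # fake gradient in (B, S, H, D) layout
assert reduce_dkv(2, grad).shape == (1, 4, 2, 8)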
D_HEAD = 4, 48, 4096, 64 +BATCH, N_HEADS, N_CTX, D_HEAD = 1, 32, None, 128 +# vary seq length for fixed head and batch=4 +configs = [triton.testing.Benchmark( + x_names=['N_CTX'], + x_vals=[2**i for i in range(18, 19)],#[ 20, 21]],#[10, 11, 12, 13, 14, 15, 16, 17, 18]], + line_arg='provider', + line_vals=['triton'] if not ONLY_FLASH else [] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] if not ONLY_FLASH else [] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []), + styles=[('red', '-'), ('blue', '-')], + ylabel='ms', + plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-{causal}', + args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.bfloat16, 'mode': mode, 'causal': causal} +) for mode in ["all"] for causal in [True]] + +# @triton.testing.perf_report(configs) +def bench_flash_attention(BATCH, H, KVH, N_CTX, D_HEAD, causal, mode, provider, args, dtype=torch.bfloat16, device="cuda"): + assert mode == "all" #mode in ['fwd', 'bwd'] + n_warmup = 10 + n_repeat = 10 + cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda') + seq_rank = get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + if provider == "triton": + q = torch.randn((BATCH, H, N_CTX // seq_world_size, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((BATCH, KVH, N_CTX // seq_world_size, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((BATCH, KVH, N_CTX // seq_world_size, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + if seq_rank == 0: + print(f"Benchmarking per GPU qkv shape: {q.shape}") + sm_scale = 1.3 + fwd_fn = lambda: attention(q, k, v, causal, sm_scale) + if provider == "flash": + qkv = torch.randn((BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True) + if FLASH_VER == 1: + lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) + cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32) + cu_seqlens[1:] = lengths.cumsum(0) + qkv = qkv.reshape(BATCH * N_CTX, 3, H, D_HEAD) + fwd_fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=causal) + elif FLASH_VER == 2: + fwd_fn = lambda: flash_attn_func(qkv, causal=causal) + else: + raise ValueError(f'unknown {FLASH_VER = }') + + flops_per_matmul = 2. 
* BATCH * H * N_CTX * N_CTX * D_HEAD / seq_world_size + attn_flops = 2 * flops_per_matmul + + assert causal + if causal: + attn_flops *= 0.5 + fwd_flops = attn_flops + bwd_flops = attn_flops * 2.5 # 2.0(bwd) + 0.5(recompute) + + o = fwd_fn() + do = torch.randn_like(o) + bwd_fn = lambda: o.backward(do, retain_graph=True) + + def run_benchmark(fn): + time_list = [] + for _ in tqdm(range(n_warmup)): + cache.zero_() + fn() + torch.cuda.synchronize() + if args.debug: + print_and_reset_comm_stats() + for i in tqdm(range(n_repeat)): + cache.zero_() + torch.cuda.synchronize() + time_s = time.time() + fn() + torch.cuda.synchronize() + time_e = time.time() + time_list.append((time_e - time_s) * 1000.0) + if args.debug: + print_and_reset_comm_stats() + return np.asarray(time_list) + + fwd_time_arr = run_benchmark(fwd_fn) + bwd_time_arr = run_benchmark(bwd_fn) + + fwd_flops_ps = fwd_flops / np.mean(fwd_time_arr) * 1e-9 + print(f"(FWD) R={seq_rank} avg: {np.mean(fwd_time_arr)}, std: {np.std(fwd_time_arr)} flops: {fwd_flops_ps} \n") + + bwd_flops_ps = bwd_flops / np.mean(bwd_time_arr) * 1e-9 + print(f"(BWD) R={seq_rank} avg: {np.mean(bwd_time_arr)}, std: {np.std(bwd_time_arr)} flops: {bwd_flops_ps} \n") + + # total + total_time_arr = fwd_time_arr + bwd_time_arr + total_flops = fwd_flops + bwd_flops + total_flops_ps = total_flops / np.mean(total_time_arr) * 1e-9 + print(f"(Total) R={seq_rank} avg: {np.mean(total_time_arr)}, std: {np.std(total_time_arr)} flops: {total_flops_ps} \n") + + #return total_flops_ps + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--comm-mode", type=str, default="lightseq") + parser.add_argument("--debug", action="store_true") + parser.add_argument("--run-mode", type=str, default="benchmark") + parser.add_argument("--bs", type=int, default=1) + parser.add_argument("--n_heads", type=int, default=32) + parser.add_argument("--n_kvheads", type=int, default=32) + parser.add_argument("--d_head", type=int, default=128) + parser.add_argument("--start_ctx", type=int, default=12) + parser.add_argument("--end_ctx", type=int, default=18) + parser.add_argument("--forward_engine", type=str, default="triton") + parser.add_argument("--backward_engine", type=str, default="flash") + + global args + args = parser.parse_args() + initialize_distributed() + + assert args.forward_engine == "triton", "Only triton forward is implmented." + assert args.backward_engine in ["flash", "xformers"], "Only flash or xformers backward is implemented." + + if args.backward_engine == "flash": + from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward + else: + try: + import xformers.ops + from xformers.ops.fmha.common import Inputs, Context + from xformers.ops.fmha import _memory_efficient_attention_backward + from xformers.ops.fmha import cutlass, flash + except ImportError: + print("xformers not found! 
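To make the FLOP accounting above concrete, the same arithmetic as in `bench_flash_attention`, packaged as a small helper. The example numbers are illustrative; times are in milliseconds, so `flops / ms * 1e-9` comes out as TFLOP/s.

def attention_tflops_per_sec(batch, heads, n_ctx, d_head, world_size, fwd_ms, bwd_ms, causal=True):
    """Achieved TFLOP/s per rank under the same accounting as bench_flash_attention."""
    flops_per_matmul = 2.0 * batch * heads * n_ctx * n_ctx * d_head / world_size
    attn_flops = 2 * flops_per_matmul            # QK^T and PV
    if causal:
        attn_flops *= 0.5                        # only the lower triangle is computed
    fwd, bwd = attn_flops, attn_flops * 2.5      # bwd = 2.0 (grads) + 0.5 (recompute)
    return fwd / fwd_ms * 1e-9, bwd / bwd_ms * 1e-9

# e.g. BATCH=1, H=32, N_CTX=2**18, D_HEAD=128 on 8 ranks, with made-up timings:
print(attention_tflops_per_sec(1, 32, 2**18, 128, 8, fwd_ms=50.0, bwd_ms=130.0))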
Please install it before trying to use it.") + + if args.run_mode == "benchmark": + for N_CTX in [2**i for i in range(args.start_ctx, args.end_ctx)]: + bench_flash_attention(args.bs, args.n_heads, args.n_kvheads, N_CTX, args.d_head, True, "all", "triton", args)#.run(save_path='.', print_data=True) + reset_global_memory_buffer() + else: + assert args.run_mode == "test" + for N_CTX in [2048, 4096]: + test_op(1, 16, N_CTX, 128, True) + #test_gqa(1, 16, 8, N_CTX, 128, True) + reset_global_memory_buffer() diff --git a/src/llamafactory/easy_context/dist_flash_attn/lightseq_async_attn_varlen.py b/src/llamafactory/easy_context/dist_flash_attn/lightseq_async_attn_varlen.py new file mode 100644 index 0000000000..388ecd4c81 --- /dev/null +++ b/src/llamafactory/easy_context/dist_flash_attn/lightseq_async_attn_varlen.py @@ -0,0 +1,772 @@ +import os +import math + +from einops import rearrange +import argparse + +import pytest +import torch +import torch.distributed as dist +from torch.distributed import ReduceOp +#from torch.profiler import profile, record_function, ProfilerActivity + +import triton +import triton.language as tl +import time +import numpy as np +from tqdm import tqdm + +try: + from flash_attn.flash_attn_interface import _flash_attn_varlen_backward +except: + pass + +from .async_communication import (is_last_time, is_compute_for_local_query, is_sync_from_remote, is_idle, print_and_reset_comm_stats, + launch_async_handles, wait_async_handles, maybe_send_recv_fwd_qkvo, maybe_send_recv_bwd_qkvo, maybe_send_recv_bwd_last_dkv, reset_global_memory_buffer, + maybe_get_set_global_memory_buffer, maybe_get_set_global_memory_buffer_bwd, initialize_distributed, get_sequence_parallel_size, get_sequence_parallel_rank) + +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) + +@triton.jit +def _rescale_kernel( + peer_m, + m, + peer_l, + l, + peer_o, + o, + L, + stride_oz, stride_oh, stride_om, stride_on, + Z, H, N_CTX, + seqlen_q_rounded, seqlen_peer_q_rounded, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + LAST_STEP: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + o_offset = off_hz * stride_oh + peer_o_block_ptr = tl.make_block_ptr( + base=peer_o + o_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + o_block_ptr = tl.make_block_ptr( + base=o + o_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + + peer_m_ptrs = peer_m + off_hz * seqlen_peer_q_rounded + offs_m + m_ptrs = m + off_hz * seqlen_q_rounded + offs_m + peer_l_ptrs = peer_l + off_hz * seqlen_peer_q_rounded + offs_m + l_ptrs = l + off_hz * seqlen_q_rounded + offs_m + + peer_m_i = tl.load(peer_m_ptrs) + peer_m_i = peer_m_i.to(tl.float32) + m_i = tl.load(m_ptrs) + m_i = m_i.to(tl.float32) + peer_l_i = tl.load(peer_l_ptrs) + peer_l_i = peer_l_i.to(tl.float32) + l_i = tl.load(l_ptrs) + l_i = l_i.to(tl.float32) + + peer_acc = tl.load(peer_o_block_ptr)#, boundary_check=(0, 1), padding_option='zero') + peer_acc = peer_acc.to(tl.float32) + acc = tl.load(o_block_ptr) #, boundary_check=(0, 1), padding_option='zero') + acc = acc.to(tl.float32) + lo = 0 + hi = N_CTX + m_i_sync = tl.maximum(m_i, peer_m_i) + alpha = tl.math.exp2(m_i - m_i_sync) + peer_alpha 
= tl.math.exp2(peer_m_i - m_i_sync) + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + peer_acc_scale = peer_l_i * 0 + peer_alpha # workaround some compiler bug + + acc *= acc_scale[:, None] + peer_acc *= peer_acc_scale[:, None] + acc += peer_acc + l_i = l_i * acc_scale + peer_l_i * peer_acc_scale + # write back O, l, m + tl.store(m_ptrs, m_i_sync) + tl.store(l_ptrs, l_i) + if LAST_STEP: + acc = acc / l_i[:, None] + L_ptrs = L + off_hz * N_CTX + offs_m + tl.store(L_ptrs, m_i_sync / 1.44269504 + tl.math.log(l_i)) + tl.store(o_block_ptr, acc.to(tl.bfloat16), boundary_check=(0, 1)) + +@triton.jit +def _fwd_kernel( + Q, K, V, sm_scale, + m, + l, + O, + L, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + stride_oz, stride_oh, stride_om, stride_on, + Z, H, N_CTX, + seqlen_q_rounded, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_CAUSAL: tl.constexpr, + LAST_STEP: tl.constexpr +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + qvk_offset = off_hz * stride_qh + Q_block_ptr = tl.make_block_ptr( + base=Q + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + K_block_ptr = tl.make_block_ptr( + base=K + qvk_offset, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0) + ) + O_block_ptr = tl.make_block_ptr( + base=O + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l -> load from provided pointer + # (TODO): Why float32? 
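The rescale step above is the standard two-way merge of partial online-softmax statistics. A host-side PyTorch sketch of the same update (base-2 exponentials, as in the kernel; the final `acc / l` normalization and the logsumexp `L = m / log2(e) + log(l)` are only applied on the last step):

import torch

def merge_partial_softmax(m, l, acc, peer_m, peer_l, peer_acc):
    """Combine two partial (row max, row sum, weighted accumulator) triples,
    mirroring the rescale performed by _rescale_kernel."""
    m_sync = torch.maximum(m, peer_m)
    alpha = torch.exp2(m - m_sync)
    peer_alpha = torch.exp2(peer_m - m_sync)
    acc = acc * alpha[:, None] + peer_acc * peer_alpha[:, None]
    l = l * alpha + peer_l * peer_alpha
    return m_sync, l, acc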
+ m_ptrs = m + off_hz * seqlen_q_rounded + offs_m + l_ptrs = l + off_hz * seqlen_q_rounded + offs_m + m_i = tl.load(m_ptrs) + m_i = m_i.to(tl.float32) + l_i = tl.load(l_ptrs) + l_i = l_i.to(tl.float32) + acc = tl.load(O_block_ptr) + acc = acc.to(tl.float32) + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout + q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option='zero') + q = (q * qk_scale).to(tl.bfloat16) + # loop over k, v and update accumulator + lo = 0 + hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX + for start_n in range(lo, hi, BLOCK_N): + # -- load k, v -- + k = tl.load(K_block_ptr, boundary_check=(1,), padding_option='zero') + v = tl.load(V_block_ptr, boundary_check=(0,), padding_option='zero') + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + if IS_CAUSAL: + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + qk += tl.dot(q, k) + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc *= acc_scale[:, None] + acc += tl.dot(p.to(tl.bfloat16), v) + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + # write back original l and m + tl.store(m_ptrs, m_i) + tl.store(l_ptrs, l_i) + # write back O, L + if LAST_STEP: + acc = acc / l_i[:, None] + L_ptrs = L + off_hz * seqlen_q_rounded + offs_m + tl.store(L_ptrs, m_i / 1.44269504 + tl.math.log(l_i)) + tl.store(O_block_ptr, acc.to(tl.bfloat16), boundary_check=(0, 1)) + +# for gqa/mqa to expand kv heads +def maybe_repeat_kv_fwd(nqh, kv): + bs, nkvh, slen, hdim = kv.shape + n_rep = nqh // nkvh + if n_rep == 1: + return kv + kv_expand = kv[:, :, None, :, :].expand(bs, nkvh, n_rep, slen, hdim) + return kv_expand.reshape(bs, nkvh * n_rep, slen, hdim) + +def maybe_repeat_kv_bwd(nqh, kv): + bs, slen, nkvh, hdim = kv.shape + n_rep = nqh // nkvh + if n_rep == 1: + return kv + kv_expand = kv[:, :, :, None, :].expand(bs, slen, nkvh, n_rep, hdim) + return kv_expand.reshape(bs, slen, nkvh * n_rep, hdim) + +# kv grad has shape bs, slen, nqh, hdim +def maybe_reduce_dkv(nkvh, dkv): + bs, slen, nqh, hdim = dkv.shape + n_rep = nqh // nkvh + if n_rep == 1: + return dkv + #print("*"*100, dkv.shape, bs, slen, nkvh, n_rep, hdim) + dkv_reshape = dkv.view(bs, slen, nkvh, n_rep, hdim) + #print("-"*100, dkv_reshape.shape, bs, slen, nkvh, n_rep, hdim) + return torch.sum(dkv_reshape, dim=3) + + +def _lightseq_forward_varlen(q, k, v, causal, sm_scale, comm_mode): + # maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x + # q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + # assert Lq == Lk and Lk == Lv + # assert Lk in {16, 32, 64, 128} + BLOCK_M = 128 + BLOCK_N = 64 + + bsz, nh, unpadded_seq_len, hdim = q.shape + cu_seq_lens = torch.arange(0, (bsz+1) * unpadded_seq_len, unpadded_seq_len, dtype=torch.int32, device=q.device) + max_seqlen = unpadded_seq_len + seqlen_q_rounded = math.ceil(q.shape[2] / BLOCK_M) * BLOCK_M + + m = torch.full((bsz * nh, seqlen_q_rounded), 
fill_value=-float("inf"), device=q.device, dtype=torch.float32) + l = torch.zeros((bsz * nh, seqlen_q_rounded), device=q.device, dtype=torch.float32) + L = torch.zeros((bsz * nh, seqlen_q_rounded), device=q.device, dtype=torch.float32) + o = torch.zeros_like(q) + + grid = (triton.cdiv(q.shape[2], BLOCK_M), bsz * nh, 1) + num_warps = 4 if Lk <= 64 else 8 + + seq_rank = get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + + # Initialize all buffers + peer_q, peer_k, peer_v, peer_m, peer_l, peer_o = maybe_get_set_global_memory_buffer(q, k, v, m, l, o) + + fwd_launch_helper = lambda q, k, v, m, l, o, L, IS_CAUSAL, LAST_STEP: _fwd_kernel[grid]( + q, k, v, sm_scale, + m, + l, + o, + L, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + q.shape[0], q.shape[1], q.shape[2], + seqlen_q_rounded, + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, + IS_CAUSAL=IS_CAUSAL, + LAST_STEP=LAST_STEP, + num_warps=num_warps, + num_stages=4) + + for time_step in range(seq_world_size // 2 + 1): + # This is important for cuda scheduler to execute nccl calls first. + torch.cuda.synchronize() + # Communication uses buffer_idx_1, and compute uses buffer_idx_2, which effectively are contents from the last time step. + buffer_idx_1 = time_step % 2 + buffer_idx_2 = (time_step - 1) % 2 + + reqs = maybe_send_recv_fwd_qkvo(q, peer_q[buffer_idx_1], k, peer_k[buffer_idx_1], v, peer_v[buffer_idx_1], + [peer_o[buffer_idx_1], peer_m[buffer_idx_1], peer_l[buffer_idx_1]], time_step, comm_mode) + if comm_mode == "sync": + # if seq_rank == 0: + # print("Immediate wait for abalation") + wait_async_handles(reqs) + if is_compute_for_local_query(time_step): + # print(f"t={time_step}: (Comp) R={seq_rank} local compute") + if time_step == 0: + fwd_launch_helper(q, maybe_repeat_kv_fwd(q.shape[1], k), maybe_repeat_kv_fwd(q.shape[1], v), m, l, o, L, True, is_last_time(time_step)) + else: + # if needs to sync from others, do not normalize here + fwd_launch_helper(q, maybe_repeat_kv_fwd(q.shape[1], peer_k[buffer_idx_2]), maybe_repeat_kv_fwd(q.shape[1], peer_v[buffer_idx_2]), m, l, o, L, False, not is_sync_from_remote(time_step) and is_last_time(time_step)) + elif is_idle(time_step): + # print(f"t={time_step}: (Comp) R={seq_rank} idle") + pass + else: + # print(f"t={time_step}: (Comp) R={seq_rank} helps other") + peer_m[buffer_idx_2] = torch.full_like(m, fill_value=-float("inf")) + peer_l[buffer_idx_2] = torch.zeros_like(l) + peer_o[buffer_idx_2] = torch.zeros_like(o) + + #print(f"rank 3 q is: {peer_q[buffer_idx_2]}") + fwd_launch_helper(peer_q[buffer_idx_2], maybe_repeat_kv_fwd(q.shape[1], k), maybe_repeat_kv_fwd(q.shape[1], v), peer_m[buffer_idx_2], peer_l[buffer_idx_2], peer_o[buffer_idx_2], None, False, False) + + if comm_mode == "lightseq": + # Make sure tensors for next steps are ready + wait_async_handles(reqs) + # sync between statistics get from other ranks and the local ones + if is_sync_from_remote(time_step): +# print(f"t={time_step}: (Comp) R={seq_rank} sync with other - last time: {is_last_time(time_step)}") + seqlen_peer_q_rounded = peer_l[buffer_idx_1].shape[-1] + _rescale_kernel[grid]( + peer_m[buffer_idx_1], + m, + peer_l[buffer_idx_1], + l, + peer_o[buffer_idx_1], + o, + L, + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + o.shape[0], o.shape[1], o.shape[2], + seqlen_q_rounded, seqlen_peer_q_rounded, + BLOCK_M=BLOCK_M, 
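The loop above runs for `seq_world_size // 2 + 1` time steps and ping-pongs between two peer buffers, so the tensors being received for step `t` never collide with the tensors (received at step `t-1`) that the kernel is currently consuming. A schematic of that double-buffering index pattern; the actual role assignment lives in `is_compute_for_local_query`, `is_idle` and `is_sync_from_remote`:

def buffer_schedule(seq_world_size):
    # Mirror of buffer_idx_1 / buffer_idx_2 in the forward loop above.
    for t in range(seq_world_size // 2 + 1):
        recv_buf = t % 2           # filled by this step's send/recv
        compute_buf = (t - 1) % 2  # holds the peer tensors received last step
        print(f"step {t}: recv into buffer {recv_buf}, compute from buffer {compute_buf}")

buffer_schedule(8)   # 5 time steps for 8 sequence-parallel ranks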
BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, + LAST_STEP=is_last_time(time_step), + num_warps=num_warps, + num_stages=4) + return q, k, v, o, L, cu_seq_lens, max_seqlen + +def _lightseq_backward_varlen(do, q, k, v, o, L, sm_scale, comm_mode, backward_engine, cu_seq_lens, max_seqlen): + BLOCK = 128 + L = rearrange(L[:, :max_seqlen].contiguous(), '(b h) s -> b h s', b=q.shape[0]) + q, k, v, o, do = [rearrange(_x, 'b h s d -> (b s) h d').contiguous() for _x in [q, k, v, o, do]] + + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + + # maybe gqa + nqh = q.shape[1] + nkvh = k.shape[1] + is_gqa = (nqh > nkvh) + + seq_rank = get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + + # Initialize all backward buffers + dq_delta, dk_delta, dv_delta, dk_delta_from_peer, dv_delta_from_peer, \ + peer_q, peer_L, peer_k, peer_v, peer_o, peer_do = maybe_get_set_global_memory_buffer_bwd(dq, dk, dv, q, L, k, v, o, do) + + for time_step in range(0, get_sequence_parallel_size() // 2 + 1): + torch.cuda.synchronize() + buffer_idx_1 = time_step % 2 + buffer_idx_2 = (time_step - 1) % 2 + + reqs, is_update_dq, is_update_dkv = maybe_send_recv_bwd_qkvo(dq_delta[buffer_idx_1], dk_delta[buffer_idx_1], dv_delta[buffer_idx_1], dk_delta_from_peer, dv_delta_from_peer, q, peer_q[buffer_idx_1], L, peer_L[buffer_idx_1], k, peer_k[buffer_idx_1], v, peer_v[buffer_idx_1], o, peer_o[buffer_idx_1], do, peer_do[buffer_idx_1], time_step, comm_mode) + if comm_mode == "sync": + wait_async_handles(reqs) + + if is_compute_for_local_query(time_step): + if time_step == 0: + assert backward_engine == "flash", "We haven't supportted varlen feature in xformer" + if backward_engine == "flash": + _flash_attn_varlen_backward(do, q, k, v, o, L, dq, dk, dv, cu_seq_lens, cu_seq_lens, max_seqlen, max_seqlen, 0.0, sm_scale, True, None) + else: + inp = Inputs(query=q, key=maybe_repeat_kv_bwd(q.shape[2], k), value=maybe_repeat_kv_bwd(q.shape[2], v), attn_bias=xformers.ops.LowerTriangularMask(), p=0, scale=sm_scale) + op_ctx = Context(lse=L, out=o, rng_state=None) + # Let xformers dispatch the correct backend + grads = _memory_efficient_attention_backward(ctx=op_ctx, inp=inp, grad=do, op=None) + dq = grads.dq + dk, dv = maybe_reduce_dkv(nkvh, grads.dk), maybe_reduce_dkv(nkvh, grads.dv) + else: + assert backward_engine == "flash", "We haven't supportted varlen feature in xformer" + if backward_engine == "flash": + _flash_attn_varlen_backward(do, q, peer_k[buffer_idx_2], peer_v[buffer_idx_2], o, L, dq_delta[buffer_idx_2], dk_delta[buffer_idx_2], dv_delta[buffer_idx_2], cu_seq_lens, cu_seq_lens, max_seqlen, max_seqlen, 0.0, sm_scale, False, None) + else: + inp = Inputs(query=q, key=maybe_repeat_kv_bwd(q.shape[2], peer_k[buffer_idx_2]), value=maybe_repeat_kv_bwd(q.shape[2], peer_v[buffer_idx_2]), attn_bias=None, p=0, scale=sm_scale) + op_ctx = Context(lse=L, out=o, rng_state=None) + grads = _memory_efficient_attention_backward(ctx=op_ctx, inp=inp, grad=do, op=None) + dq_delta[buffer_idx_2] = grads.dq + dk_delta[buffer_idx_2], dv_delta[buffer_idx_2] = maybe_reduce_dkv(nkvh, grads.dk), maybe_reduce_dkv(nkvh, grads.dv) + dq += dq_delta[buffer_idx_2] + elif is_idle(time_step): + # print(f"BWD t={time_step}: (Comp) R={seq_rank} idle") + pass + else: + # print(f"BWD t={time_step}: (Comp) R={seq_rank} helps other") + assert backward_engine == "flash", "We haven't supportted varlen feature in xformer" + if backward_engine == "flash": + _flash_attn_varlen_backward(peer_do[buffer_idx_2], peer_q[buffer_idx_2], k, v, 
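The backward path above flattens each (B, H, S, D) tensor into the packed layout that `_flash_attn_varlen_backward` expects, and the forward builds `cu_seq_lens` for equal-length sequences with a single `arange`. A small sketch of that packing (the helper name is hypothetical):

import torch
from einops import rearrange

def pack_for_varlen(x):
    """(B, H, S, D) -> packed (B*S, H, D) plus cu_seqlens, as used by the
    varlen flash-attn calls above (all sequences share length S here)."""
    b, h, s, d = x.shape
    packed = rearrange(x, 'b h s d -> (b s) h d').contiguous()
    cu_seqlens = torch.arange(0, (b + 1) * s, s, dtype=torch.int32, device=x.device)
    return packed, cu_seqlens, s   # s doubles as max_seqlen

x = torch.randn(2, 4, 8, 16)
packed, cu, max_len = pack_for_varlen(x)
assert packed.shape == (16, 4, 16) and cu.tolist() == [0, 8, 16]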
peer_o[buffer_idx_2], peer_L[buffer_idx_2], dq_delta[buffer_idx_2], dk_delta[buffer_idx_2], dv_delta[buffer_idx_2], cu_seq_lens, cu_seq_lens, max_seqlen, max_seqlen, 0.0, sm_scale, False, None) + else: + inp = Inputs(query=peer_q[buffer_idx_2], key=maybe_repeat_kv_bwd(q.shape[2], k), value=maybe_repeat_kv_bwd(q.shape[2], v), attn_bias=None, p=0, scale=sm_scale) + op_ctx = Context(lse=peer_L[buffer_idx_2], out=peer_o[buffer_idx_2], rng_state=None) + grads = _memory_efficient_attention_backward(ctx=op_ctx, inp=inp, grad=peer_do[buffer_idx_2], op=None) + dq_delta[buffer_idx_2] = grads.dq + dk_delta[buffer_idx_2], dv_delta[buffer_idx_2] = maybe_reduce_dkv(nkvh, grads.dk), maybe_reduce_dkv(nkvh, grads.dv) + dk += dk_delta[buffer_idx_2] + dv += dv_delta[buffer_idx_2] + + if comm_mode == "lightseq": + # Make sure tensors for next steps are ready + wait_async_handles(reqs) + + # The last time step needs to send dk and dv immediately, move it up here to maximize overlap with the following three addition. + reqs, is_update_last_dkv = maybe_send_recv_bwd_last_dkv(dk_delta[buffer_idx_2], dv_delta[buffer_idx_2], time_step, comm_mode) + + if comm_mode == "sync": + # if seq_rank == 0: + # print("(bwd) dkv Immediate wait for abalation") + wait_async_handles(reqs) + # apply dq_delta, dk_delta and dv_delta from remote + if is_update_dq: + dq += dq_delta[buffer_idx_1] + if is_update_dkv: + dk += dk_delta_from_peer + dv += dv_delta_from_peer + + if comm_mode == "lightseq": + wait_async_handles(reqs) + # apply dk_delta and dv_delta to sender + if is_update_last_dkv: + dk += dk_delta[buffer_idx_2] + dv += dv_delta[buffer_idx_2] + + dq, dk, dv = [rearrange(_x, '(b s) h d -> b h s d', s=max_seqlen) for _x in [dq, dk, dv]] + return dq, dk, dv + +class _attention_varlen(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, causal, sm_scale): + try: + global args + comm_mode = args.comm_mode + backward_engine = args.backward_engine + except: + comm_mode = 'lightseq' + backward_engine = 'flash' + + q, k, v, o, L, cu_seq_lens, max_seqlen = _lightseq_forward_varlen(q, k, v, causal, sm_scale, comm_mode) + + ctx.save_for_backward(q, k, v, o, L, cu_seq_lens) + ctx.max_seqlen = max_seqlen + ctx.sm_scale = sm_scale + ctx.comm_mode = comm_mode + ctx.backward_engine = backward_engine + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, L, cu_seq_lens = ctx.saved_tensors + sm_scale = ctx.sm_scale + max_seqlen = ctx.max_seqlen + + dq, dk, dv = _lightseq_backward_varlen(do, q, k, v, o, L, sm_scale, ctx.comm_mode, ctx.backward_engine, cu_seq_lens, max_seqlen) + return dq, dk, dv, None, None + +dist_attn_varlen = _attention_varlen.apply + + +#@pytest.mark.parametrize('causal', [False, True]) +#@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(6, 9, 1024, 64)]) +def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.bfloat16): + torch.manual_seed(20) + rank = dist.get_rank() + world_size = dist.get_world_size() + + + PAD = world_size * 256 + seq_per_rank = (N_CTX-PAD) // world_size + q = torch.empty((Z, H, N_CTX-PAD, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + k = torch.empty((Z, H, N_CTX-PAD, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + v = torch.empty((Z, H, N_CTX-PAD, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + + # DEBUG: mask out + #mask = torch.zeros(Z, H, seq_per_rank * (world_size - 1), D_HEAD).cuda() + #mask_2 = torch.ones(Z, H, seq_per_rank, D_HEAD).cuda() + #mask = torch.cat((mask, 
mask_2), dim=-2).to(dtype) + #q = mask * q + #k = mask * k + #v = mask * v + + sm_scale = 0.5 + dout = torch.randn_like(q) + # reference implementation + M = torch.tril(torch.ones((N_CTX-PAD, N_CTX-PAD), device="cuda")) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + assert causal + if causal: + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).half() + ref_out = torch.matmul(p, v) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + + # triton implementation + + a, b, c, d = q.size() + real_q = q[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True) + real_k = k[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True) + real_v = v[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True) + real_do = dout[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True) + + tri_out = dist_attn_varlen(real_q, real_k, real_v, causal, sm_scale).half() + + # compare + assert torch.allclose(ref_out[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_out, atol=1e-2, rtol=0), f" rank {rank} fails forward" + print(f" *** rank {rank} passes forward") + tri_out.backward(real_do) + tri_dv, real_v.grad = real_v.grad.clone(), None + tri_dk, real_k.grad = real_k.grad.clone(), None + tri_dq, real_q.grad = real_q.grad.clone(), None + assert torch.allclose(ref_dq[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dq, atol=1e-2, rtol=0), f"rank {rank} fails backward dq" #{ref_dq[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dq} {torch.max(ref_dq[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dq)} rank {rank} fails backward dk" + assert torch.allclose(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dk, atol=1e-2, rtol=0), f"rank {rank} fails backward dk" #{ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dk} {torch.max(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dk)} rank {rank} fails backward dk" + assert torch.allclose(ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dv, atol=1e-2, rtol=0), f"rank {rank} fails backward dv" #{ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dv} {torch.max(ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dv)} rank {rank} fails backward dv" + print(f"rank {rank} passes backward") + +#TODO(High Priority): Investigate why rank 0 tends to have larger numerical difference. 
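The varlen `test_op` above (and `test_gqa` below) repeats the same pattern: slice the single-GPU reference down to this rank's sequence shard and compare it against the distributed output with `atol=1e-2`. A small helper of the kind that could factor that out (names are illustrative, not part of the patch):

import torch

def rank_shard(x, rank, world_size):
    """This rank's contiguous slice along the sequence dim (dim=2), the same
    slicing the tests apply to the single-GPU reference tensors."""
    seq_per_rank = x.shape[2] // world_size
    return x[:, :, rank * seq_per_rank:(rank + 1) * seq_per_rank, :]

def check_shard(name, ref_full, local_out, rank, world_size, atol=1e-2):
    ref = rank_shard(ref_full, rank, world_size)
    assert torch.allclose(ref, local_out, atol=atol, rtol=0), f"rank {rank} fails {name}"
    print(f"rank {rank} passes {name}")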
+def test_gqa(Z, H, KVH, N_CTX, D_HEAD, causal, dtype=torch.bfloat16): + torch.manual_seed(177) + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + k = torch.empty((Z, KVH, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + v = torch.empty((Z, KVH, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + + rank = dist.get_rank() + world_size = dist.get_world_size() + seq_per_rank = N_CTX // world_size + + sm_scale = 0.5 + dout = torch.randn_like(q) + # torch reference implementation + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + ref_k = maybe_repeat_kv_fwd(q.shape[1], k).clone().detach().requires_grad_(True) + ref_v = maybe_repeat_kv_fwd(q.shape[1], v).clone().detach().requires_grad_(True) + #print(q.shape, ref_k.shape, k.shape) + p = torch.matmul(q, ref_k.transpose(2,3)) * sm_scale + assert causal + if causal: + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).half() + ref_out = torch.matmul(p, ref_v) + ref_out.backward(dout) + ref_dv, v.grad = ref_v.grad.clone(), None + #print("Before reduce", ref_dv.shape) + ref_dv = (maybe_reduce_dkv(KVH, ref_dv.transpose(1,2))).transpose(1,2) + #print("After reduce", ref_dv.shape) + ref_dk, k.grad = ref_k.grad.clone(), None + ref_dk = (maybe_reduce_dkv(KVH, ref_dk.transpose(1,2))).transpose(1,2) + ref_dq, q.grad = q.grad.clone(), None + + # flash reference + from flash_attn import flash_attn_qkvpacked_func, flash_attn_func + flash_q = q.transpose(1,2).clone().detach().requires_grad_(True) + flash_k = k.transpose(1,2).clone().detach().requires_grad_(True) + flash_v = v.transpose(1,2).clone().detach().requires_grad_(True) + flash_ref_out = flash_attn_func(flash_q, flash_k, flash_v, 0, sm_scale, True) + flash_ref_out.backward(dout.transpose(1,2)) + flash_ref_out = flash_ref_out.transpose(1,2) + flash_ref_dv, v.grad = flash_v.grad.clone(), None + flash_ref_dv = flash_ref_dv.transpose(1,2) + flash_ref_dk, k.grad = flash_k.grad.clone(), None + flash_ref_dk = flash_ref_dk.transpose(1,2) + flash_ref_dq, q.grad = flash_q.grad.clone(), None + flash_ref_dq = flash_ref_dq.transpose(1,2) + + # triton implementation + + a, b, c, d = q.size() + real_q = q[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True) + real_k = k[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, KVH, -1, d).contiguous().clone().detach().requires_grad_(True) + real_v = v[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, KVH, -1, d).contiguous().clone().detach().requires_grad_(True) + real_do = dout[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True) + + tri_out = dist_attn_varlen(real_q, real_k, real_v, causal, sm_scale).half() + + # compare + assert torch.allclose(flash_ref_out[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_out, atol=1e-2, rtol=0), f" rank {rank} fails forward against flash" + print(f" *** rank {rank} passes forward") + tri_out.backward(real_do) + tri_dv, real_v.grad = real_v.grad.clone(), None + tri_dk, real_k.grad = real_k.grad.clone(), None + tri_dq, real_q.grad = real_q.grad.clone(), None + assert torch.allclose(flash_ref_dq[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dq, atol=1e-2, rtol=0), f" rank {rank} fails backward dq against flash" + #print(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * 
seq_per_rank, :].shape, ref_dk.shape, tri_dk.shape) + assert torch.allclose(flash_ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dk, atol=1e-2, rtol=0), f"rank {rank} fails backward dk against flash {flash_ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dk} {torch.max(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dk)} rank {rank} fails backward dk" + assert torch.allclose(flash_ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dv, atol=1e-2, rtol=0), f"rank {rank} fails backward dv against flash {flash_ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dv} {torch.max(flash_ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dv)} rank {rank} fails backward dv" + print(f"rank {rank} passes backward against flash") + + assert torch.allclose(ref_out[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_out, atol=1e-2, rtol=0), f" rank {rank} fails forward" + print(f" *** rank {rank} passes forward") + assert torch.allclose(ref_dq[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dq, atol=1e-2, rtol=0), f" rank {rank} fails backward dq" + #print(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].shape, ref_dk.shape, tri_dk.shape) + assert torch.allclose(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dk, atol=1e-2, rtol=0), f"rank {rank} fails backward dk {ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dk} {torch.max(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dk)} rank {rank} fails backward dk" + assert torch.allclose(ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dv, atol=1e-2, rtol=0), f"rank {rank} fails backward dv {ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dv} {torch.max(ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dv)} rank {rank} fails backward dv" + print(f"rank {rank} passes backward") + +#BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 +try: + from flash_attn.flash_attn_interface import \ + flash_attn_qkvpacked_func as flash_attn_func + FLASH_VER = 2 +except BaseException: + try: + from flash_attn.flash_attn_interface import flash_attn_func + FLASH_VER = 1 + except BaseException: + FLASH_VER = None +HAS_FLASH = FLASH_VER is not None +HAS_FLASH = None +ONLY_FLASH = False + +#BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 +BATCH, N_HEADS, N_CTX, D_HEAD = 1, 32, None, 128 +# vary seq length for fixed head and batch=4 +configs = [triton.testing.Benchmark( + x_names=['N_CTX'], + x_vals=[2**i for i in range(18, 19)],#[ 20, 21]],#[10, 11, 12, 13, 14, 15, 16, 17, 18]], + line_arg='provider', + line_vals=['triton'] if not ONLY_FLASH else [] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] if not ONLY_FLASH else [] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []), + styles=[('red', '-'), ('blue', '-')], + ylabel='ms', + plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-{causal}', + args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.bfloat16, 'mode': mode, 'causal': causal} +) for mode in ["all"] for causal in [True]] + +# @triton.testing.perf_report(configs) +def bench_flash_attention(BATCH, H, KVH, N_CTX, D_HEAD, causal, mode, provider, args, dtype=torch.bfloat16, device="cuda"): + assert mode == "all" #mode in ['fwd', 'bwd'] + n_warmup = 10 + n_repeat = 10 + cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda') + seq_rank = 
get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + if provider == "triton": + q = torch.randn((BATCH, H, N_CTX // seq_world_size, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((BATCH, KVH, N_CTX // seq_world_size, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((BATCH, KVH, N_CTX // seq_world_size, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + if seq_rank == 0: + print(f"Benchmarking per GPU qkv shape: {q.shape}") + sm_scale = 1.3 + fwd_fn = lambda: dist_attn_varlen(q, k, v, causal, sm_scale) + if provider == "flash": + qkv = torch.randn((BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True) + if FLASH_VER == 1: + lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) + cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32) + cu_seqlens[1:] = lengths.cumsum(0) + qkv = qkv.reshape(BATCH * N_CTX, 3, H, D_HEAD) + fwd_fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=causal) + elif FLASH_VER == 2: + fwd_fn = lambda: flash_attn_func(qkv, causal=causal) + else: + raise ValueError(f'unknown {FLASH_VER = }') + + flops_per_matmul = 2. * BATCH * H * N_CTX * N_CTX * D_HEAD / seq_world_size + attn_flops = 2 * flops_per_matmul + + assert causal + if causal: + attn_flops *= 0.5 + fwd_flops = attn_flops + bwd_flops = attn_flops * 2.5 # 2.0(bwd) + 0.5(recompute) + + o = fwd_fn() + do = torch.randn_like(o) + bwd_fn = lambda: o.backward(do, retain_graph=True) + + def run_benchmark(fn): + time_list = [] + for _ in tqdm(range(n_warmup)): + cache.zero_() + fn() + torch.cuda.synchronize() + if args.debug: + print_and_reset_comm_stats() + for i in tqdm(range(n_repeat)): + cache.zero_() + torch.cuda.synchronize() + time_s = time.time() + fn() + torch.cuda.synchronize() + time_e = time.time() + time_list.append((time_e - time_s) * 1000.0) + if args.debug: + print_and_reset_comm_stats() + return np.asarray(time_list) + + fwd_time_arr = run_benchmark(fwd_fn) + bwd_time_arr = run_benchmark(bwd_fn) + + fwd_flops_ps = fwd_flops / np.mean(fwd_time_arr) * 1e-9 + print(f"(FWD) R={seq_rank} avg: {np.mean(fwd_time_arr)}, std: {np.std(fwd_time_arr)} flops: {fwd_flops_ps} \n") + + bwd_flops_ps = bwd_flops / np.mean(bwd_time_arr) * 1e-9 + print(f"(BWD) R={seq_rank} avg: {np.mean(bwd_time_arr)}, std: {np.std(bwd_time_arr)} flops: {bwd_flops_ps} \n") + + # total + total_time_arr = fwd_time_arr + bwd_time_arr + total_flops = fwd_flops + bwd_flops + total_flops_ps = total_flops / np.mean(total_time_arr) * 1e-9 + print(f"(Total) R={seq_rank} avg: {np.mean(total_time_arr)}, std: {np.std(total_time_arr)} flops: {total_flops_ps} \n") + + #return total_flops_ps + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--comm-mode", type=str, default="lightseq") + parser.add_argument("--debug", action="store_true") + parser.add_argument("--run-mode", type=str, default="test") + parser.add_argument("--bs", type=int, default=1) + parser.add_argument("--n_heads", type=int, default=32) + parser.add_argument("--n_kvheads", type=int, default=32) + parser.add_argument("--d_head", type=int, default=128) + parser.add_argument("--start_ctx", type=int, default=12) + parser.add_argument("--end_ctx", type=int, default=18) + parser.add_argument("--forward_engine", type=str, default="triton") + parser.add_argument("--backward_engine", type=str, default="flash") + + global args + args = parser.parse_args() + initialize_distributed() + + assert 
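For a rough sense of scale, the per-rank q/k/v tensors allocated in `bench_flash_attention` shrink linearly with the sequence-parallel world size; a quick back-of-the-envelope helper (bf16 assumed, numbers illustrative only):

def per_rank_qkv_bytes(batch, heads, kv_heads, n_ctx, d_head, world_size, bytes_per_el=2):
    """Approximate footprint of the q/k/v benchmark tensors on one rank,
    matching the shapes allocated in bench_flash_attention."""
    s = n_ctx // world_size
    q = batch * heads * s * d_head
    kv = 2 * batch * kv_heads * s * d_head
    return (q + kv) * bytes_per_el

print(per_rank_qkv_bytes(1, 32, 32, 2**18, 128, 8) / 2**30, "GiB")   # 0.75 GiB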
args.forward_engine == "triton", "Only triton forward is implmented." + assert args.backward_engine in ["flash", "xformers"], "Only flash or xformers backward is implemented." + + if args.backward_engine == "flash": + from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward + else: + try: + import xformers.ops + from xformers.ops.fmha.common import Inputs, Context + from xformers.ops.fmha import _memory_efficient_attention_backward + from xformers.ops.fmha import cutlass, flash + except ImportError: + print("xformers not found! Please install it before trying to use it.") + + if args.run_mode == "benchmark": + for N_CTX in [2**i for i in range(args.start_ctx, args.end_ctx)]: + bench_flash_attention(args.bs, args.n_heads, args.n_kvheads, N_CTX, args.d_head, True, "all", "triton", args)#.run(save_path='.', print_data=True) + reset_global_memory_buffer() + else: + assert args.run_mode == "test" + for N_CTX in [4096]: + test_op(2, 16, N_CTX, 128, True) + #test_gqa(1, 16, 8, N_CTX, 128, True) + reset_global_memory_buffer() diff --git a/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py b/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py new file mode 100644 index 0000000000..6903c5f2bc --- /dev/null +++ b/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py @@ -0,0 +1,719 @@ +""" +Materialization-aware gradient checkpointing monkey patch. +""" +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn +import torch.nn.functional as F +from torch.utils.checkpoint import _get_autocast_kwargs, check_backward_validity, get_device_states, set_device_states, detach_variable + +import transformers +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.utils import logging + +from einops import rearrange + +from .lightseq_async_attn import _lightseq_forward, _lightseq_backward +from .async_communication import initialize_distributed, reset_global_memory_buffer +from transformers.cache_utils import Cache +from .offload_buffer import offload_buffer, OffloadBuffer + +# define a global buffer to save flash attention outputs +# it's called global because it saves the outputs for all layers +global_flash_attn_out_buffer = None + +# define a local buffer to save recomputed qkv +# it's called local because it's a temporary buffer which will be updated across layers +local_res_grad_buffer = None + +# hooks for the gradients of residual +global_hooks = [] + +logger = logging.get_logger(__name__) +def init_flash_attn_buffers(num_layers): + # update the global buffer according to number of layers + global global_flash_attn_out_buffer + global_flash_attn_out_buffer = [None] * num_layers + +def clean_hook(): + # Remove all hooks in the global buffer + for hook in global_hooks: + hook.remove() + # Clear the global buffer + global_hooks.clear() + +def clear_all_buffers_at_the_end_of_training(): + # call it at the end of training + global global_flash_attn_out_buffer + global_flash_attn_out_buffer = None + global local_res_grad_buffer + local_res_grad_buffer = None + clean_hook() + +def save_flash_attn_out_to_global_buffer(idx, out): + global global_flash_attn_out_buffer + global_flash_attn_out_buffer[idx] = out + +def get_flash_attn_out_from_global_buffer(idx): + global global_flash_attn_out_buffer + return global_flash_attn_out_buffer[idx] + +def free_flash_attn_out_buffer(idx): + global global_flash_attn_out_buffer + 
global_flash_attn_out_buffer[idx] = None + +def write_gradient_to_flash_attn_out(idx, grad): + global global_flash_attn_out_buffer + global_flash_attn_out_buffer[idx].grad = grad + +def save_res_grad_hook(grad): + global local_res_grad_buffer + local_res_grad_buffer = grad + +def load_and_add_res_grad_hook(grad): + grad += get_res_grad_from_local_buffer() + +def get_res_grad_from_local_buffer(): + global local_res_grad_buffer + assert local_res_grad_buffer is not None + return local_res_grad_buffer + +class CheckpointFunctionEndWithFlashAttention(torch.autograd.Function): + """ Avoid doing twice flash attention forward during checkpointed backward. + args: + hidden_states, # i.e., flash attention output which is saved in global buffer. + attention_mask, + position_ids, + residual, # the gradient of residual is saved in local buffer to pass across ckpt layers. + """ + + @staticmethod + def forward(ctx, run_function, layer_idx, preserve_rng_state, *args): + check_backward_validity(args) + ctx.run_function = run_function + ctx.layer_idx = layer_idx + ctx.preserve_rng_state = preserve_rng_state + # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu. + ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs() + if preserve_rng_state: + ctx.fwd_cpu_state = torch.get_rng_state() + # Don't eagerly initialize the cuda context by accident. + # (If the user intends that the context is initialized later, within their + # run_function, we SHOULD actually stash the cuda state here. Unfortunately, + # we have no way to anticipate this will happen before we run the function.) + ctx.had_cuda_in_fwd = False + if torch.cuda._initialized: + ctx.had_cuda_in_fwd = True + ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args) + + # Save non-tensor inputs in ctx, keep a placeholder None for tensors + # to be filled out during the backward. 
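The two hooks above pass the gradient of the residual across the checkpoint boundary: `save_res_grad_hook` stashes it in a module-level buffer and `load_and_add_res_grad_hook` folds it back in during recomputation of the next layer. A minimal standalone illustration of that hand-off (the patch adds in place via `grad +=`; returning a new tensor, as below, is the equivalent documented form):

import torch

stash = {}

def save_hook(grad):
    stash["res"] = grad.clone()        # like save_res_grad_hook

def add_hook(grad):
    return grad + stash["res"]         # like load_and_add_res_grad_hook

x = torch.randn(4, requires_grad=True)
y = 2 * x
y.register_hook(save_hook)
y.sum().backward()                     # stash["res"] = ones(4)

z = torch.randn(4, requires_grad=True)
w = 3 * z
w.register_hook(add_hook)
w.sum().backward()
print(z.grad)                          # 3 * (1 + 1) = 6 everywhere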
+ global offload_buffer + ctx.inputs = [] + ctx.tensor_indices = [] + ctx.tensor_indices_dict = {} + tensor_inputs = [] + + hidden_state = None + position_ids = None + for i, arg in enumerate(args): + if i == 0 and ctx.layer_idx != 0: + # flash attention output is saved to the global buffer during forward + ctx.inputs.append(None) + else: + if torch.is_tensor(arg): + if offload_buffer.enable_offload: + if len(arg.shape) == 3: + hidden_state = arg + ctx.tensor_indices_dict[i] = 'hidden_state' + elif len(arg.shape) == 2: + position_ids = arg + ctx.tensor_indices_dict[i] = 'position_ids' + else: + tensor_inputs.append(arg) + ctx.tensor_indices.append(i) + ctx.inputs.append(None) + else: + ctx.inputs.append(arg) + + with torch.no_grad(): + q, k, v, residual = run_function(*args) + softmax_scale = q.shape[-1] ** (-0.5) + + # lightseq version + _, _, _, out, softmax_lse = _lightseq_forward(q, k, v, True, softmax_scale, comm_mode='lightseq') + rng_state = None + # save flash attention output to global buffer + if offload_buffer.enable_offload: + offload_buffer.save_flash_attn_out(ctx.layer_idx, out) + offload_buffer.save_hidden_states(ctx.layer_idx, hidden_state, position_ids) + ctx.save_for_backward(softmax_lse) + else: + save_flash_attn_out_to_global_buffer(ctx.layer_idx, out) + tensor_inputs += [softmax_lse] + ctx.save_for_backward(*tensor_inputs) + + ctx.softmax_scale = softmax_scale + + return out, residual + + @staticmethod + def backward(ctx, *args): + if not torch.autograd._is_checkpoint_valid(): + raise RuntimeError( + "Checkpointing is not compatible with .grad() or when an `inputs` parameter" + " is passed to .backward(). Please use .backward() and do not pass its `inputs`" + " argument.") + # Copy the list to avoid modifying original list. + global offload_buffer + inputs = list(ctx.inputs) + tensor_indices = ctx.tensor_indices + tensor_indices_dict = ctx.tensor_indices_dict + tensors = ctx.saved_tensors + if offload_buffer.enable_offload: + softmax_lse = tensors[0] + hidden_state, position_ids = offload_buffer.get_hidden_states(ctx.layer_idx) + if ctx.layer_idx > 0: + inputs[0] = offload_buffer.get_flash_attn_out(ctx.layer_idx-1) + for k, v in tensor_indices_dict.items(): + if v == 'hidden_state': + inputs[k] = hidden_state + if v == 'position_ids': + inputs[k] = position_ids + else: + tensors, softmax_lse = tensors[:-1], tensors[-1] + if ctx.layer_idx > 0: + inputs[0] = get_flash_attn_out_from_global_buffer(ctx.layer_idx-1) + for i, idx in enumerate(tensor_indices): + inputs[idx] = tensors[i] + + # Stash the surrounding rng state, and mimic the state that was + # present at this time during forward. Restore the surrounding state + # when we're done. 
+ rng_devices = [] + if ctx.preserve_rng_state and ctx.had_cuda_in_fwd: + rng_devices = ctx.fwd_gpu_devices + with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state): + if ctx.preserve_rng_state: + torch.set_rng_state(ctx.fwd_cpu_state) + if ctx.had_cuda_in_fwd: + set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states) + detached_inputs = detach_variable(tuple(inputs)) + with torch.enable_grad(), \ + torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \ + torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): + # Stop recomputation before flash attention + # It is unecessary to run recomputation for flash attn + q, k, v, residual = ctx.run_function(*detached_inputs) + + # run backward() with only tensor that requires grad + # run flash attention backward first: + # get 'dout' from auto_grad inputs + # get 'out' from global buffer + # get 'qkv' from the recomputed tensors + #dq = torch.empty(q.shape, dtype=q.dtype, device=q.device) + #dk = torch.empty(k.shape, dtype=q.dtype, device=q.device) + #dv = torch.empty(v.shape, dtype=q.dtype, device=q.device) + # out = get_flash_attn_out_from_global_buffer(ctx.layer_idx) + if offload_buffer.enable_offload: + out = offload_buffer.get_flash_attn_out(ctx.layer_idx) + else: + out = get_flash_attn_out_from_global_buffer(ctx.layer_idx) + + # todo get dout + dout = args[0] + + # lightseq version + dq, dk, dv = _lightseq_backward(dout, q, k, v, out, softmax_lse, ctx.softmax_scale, comm_mode='lightseq', backward_engine='flash') + #dqkv = torch.stack([dq, dk, dv]) + + # run backward for the part before flash attention + #qkv.backward(dqkv) + torch.autograd.backward([q, k, v], [dq, dk, dv]) + + grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None + for inp in detached_inputs) + # write flash attention output gradients to buffer + if offload_buffer.enable_offload: + if ctx.layer_idx > 0: + offload_buffer.get_flash_attn_out(ctx.layer_idx-1).grad = detached_inputs[0].grad + offload_buffer.free_layer_gpu_buffer(ctx.layer_idx) + else: + if ctx.layer_idx > 0: + write_gradient_to_flash_attn_out(ctx.layer_idx-1, detached_inputs[0].grad) + return (None, None, None) + grads + + +def checkpoint_end_with_flash_attention(function, layer_idx, *args, use_reentrant: bool = True, **kwargs): + # Hack to mix *args with **kwargs in a python 2.7-compliant way + preserve = kwargs.pop('preserve_rng_state', True) + if kwargs and use_reentrant: + raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)) + + return CheckpointFunctionEndWithFlashAttention.apply(function, layer_idx, preserve, *args) + + +class CheckpointFunctionLastModule(torch.autograd.Function): + """ + for the last ffn layer after flash attention, modifications include: + write the gradients wrt flash attention output and residual to the global buffer. + """ + + @staticmethod + def forward(ctx, run_function, preserve_rng_state, *args): + check_backward_validity(args) + ctx.run_function = run_function + ctx.preserve_rng_state = preserve_rng_state + # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu. + ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs() + if preserve_rng_state: + ctx.fwd_cpu_state = torch.get_rng_state() + # Don't eagerly initialize the cuda context by accident. + # (If the user intends that the context is initialized later, within their + # run_function, we SHOULD actually stash the cuda state here. 
Unfortunately, + # we have no way to anticipate this will happen before we run the function.) + ctx.had_cuda_in_fwd = False + if torch.cuda._initialized: + ctx.had_cuda_in_fwd = True + ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args) + + # Save non-tensor inputs in ctx, keep a placeholder None for tensors + # to be filled out during the backward. + ctx.inputs = [] + ctx.tensor_indices = [] + tensor_inputs = [] + + assert torch.is_tensor(args[0]), "assuming the first tensor is the flash attention output" + for i, arg in enumerate(args): + if torch.is_tensor(arg) and i == 0: + # flash attn output has been saved to global buffer + ctx.inputs.append(None) + elif torch.is_tensor(arg): + tensor_inputs.append(arg) + ctx.tensor_indices.append(i) + ctx.inputs.append(None) + else: + ctx.inputs.append(arg) + + ctx.save_for_backward(*tensor_inputs) + + with torch.no_grad(): + outputs = run_function(*args) + return outputs + + @staticmethod + def backward(ctx, *args): + if not torch.autograd._is_checkpoint_valid(): + raise RuntimeError( + "Checkpointing is not compatible with .grad() or when an `inputs` parameter" + " is passed to .backward(). Please use .backward() and do not pass its `inputs`" + " argument.") + # Copy the list to avoid modifying original list. + inputs = list(ctx.inputs) + tensor_indices = ctx.tensor_indices + tensors = ctx.saved_tensors + global offload_buffer + # Fill in inputs with appropriate saved tensors. + # Fill the flash attention output first + # inputs[0] should be flash attention output + if offload_buffer.enable_offload: + inputs[0] = offload_buffer.get_flash_attn_out(-1) + else: + inputs[0] = get_flash_attn_out_from_global_buffer(-1) + for i, idx in enumerate(tensor_indices): + inputs[idx] = tensors[i] + + # Stash the surrounding rng state, and mimic the state that was + # present at this time during forward. Restore the surrounding state + # when we're done. 
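Both checkpoint Functions above follow the reentrant-checkpoint recipe: record the CPU and CUDA RNG state in forward, then replay it under `torch.random.fork_rng` so the recomputation sees identical randomness. A simplified sketch of that save/restore pattern (the real code additionally restores per-device states captured with `get_device_states` and re-enters the saved autocast contexts):

import torch

def record_rng():
    cpu_state = torch.get_rng_state()
    cuda_states = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
    return cpu_state, cuda_states

def recompute_with(cpu_state, cuda_states, fn, *args):
    devices = list(range(torch.cuda.device_count())) if cuda_states is not None else []
    with torch.random.fork_rng(devices=devices):
        torch.set_rng_state(cpu_state)
        if cuda_states is not None:
            torch.cuda.set_rng_state_all(cuda_states)
        with torch.enable_grad():
            return fn(*args)   # recomputation sees the forward-pass RNG stream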
+ rng_devices = [] + if ctx.preserve_rng_state and ctx.had_cuda_in_fwd: + rng_devices = ctx.fwd_gpu_devices + with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state): + if ctx.preserve_rng_state: + torch.set_rng_state(ctx.fwd_cpu_state) + if ctx.had_cuda_in_fwd: + set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states) + detached_inputs = detach_variable(tuple(inputs)) + with torch.enable_grad(), \ + torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \ + torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): + outputs = ctx.run_function(*detached_inputs) + + if isinstance(outputs, torch.Tensor): + outputs = (outputs,) + + # run backward() with only tensor that requires grad + outputs_with_grad = [] + args_with_grad = [] + for i in range(len(outputs)): + if torch.is_tensor(outputs[i]) and outputs[i].requires_grad: + outputs_with_grad.append(outputs[i]) + args_with_grad.append(args[i]) + if len(outputs_with_grad) == 0: + raise RuntimeError( + "none of output has requires_grad=True," + " this checkpoint() is not necessary") + torch.autograd.backward(outputs_with_grad, args_with_grad) + grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None + for inp in detached_inputs) + # write flash attention output gradients to buffer + if offload_buffer.enable_offload: + offload_buffer.get_flash_attn_out(-1).grad = detached_inputs[0].grad + else: + write_gradient_to_flash_attn_out(-1, detached_inputs[0].grad) + return (None, None) + grads + +def checkpoint_last_module(function, *args, use_reentrant: bool = True, **kwargs): + preserve = kwargs.pop('preserve_rng_state', True) + if kwargs and use_reentrant: + raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)) + + return CheckpointFunctionLastModule.apply(function, preserve, *args) + + +def llama_layer_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + compute_attn_only: Optional[bool] = False, + compute_ffn_only: Optional[bool] = False, + residual: Optional[bool] = None, +) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + assert compute_ffn_only or compute_attn_only + + if compute_attn_only: + residual = hidden_states + + if residual.requires_grad: + # register a hook to add the gradient of residual + # from next checkpoint layer when doing recomputation + hook = residual.register_hook(load_and_add_res_grad_hook) + global_hooks.append(hook) + + hidden_states = self.input_layernorm(hidden_states) + + # Flash Attention + bsz, q_len, _ = hidden_states.size() + try: + query_states = self.self_attn.q_proj(hidden_states).view(bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim).transpose(1, 2) + key_states = self.self_attn.k_proj(hidden_states).view(bsz, q_len, self.self_attn.num_key_value_heads, self.self_attn.head_dim).transpose(1, 2) + value_states = self.self_attn.v_proj(hidden_states).view(bsz, q_len, self.self_attn.num_key_value_heads, self.self_attn.head_dim).transpose(1, 2) + except: + # old transformers versions don't support num_key_value_heads + query_states = self.self_attn.q_proj(hidden_states).view(bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim).transpose(1, 2) + key_states = self.self_attn.k_proj(hidden_states).view(bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim).transpose(1, 2) + value_states = self.self_attn.v_proj(hidden_states).view(bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + assert past_key_value is None, "past_key_value is not supported" + + cos, sin = self.self_attn.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + # [bsz, nh, t, hd] + assert not output_attentions, "output_attentions is not supported" + assert not use_cache, "use_cache is not supported" + return query_states.contiguous(), key_states.contiguous(), value_states.contiguous(), residual + + elif compute_ffn_only: + hidden_states = self.self_attn.o_proj(rearrange(hidden_states, 'b h s d -> b s (h d)')) + # Need to add residual here to make sure checkpoint is right after attention + if residual.requires_grad: + # save the gradient of residual to the local buffer + # collect the hooks which should be removed after backward to avoid memory leak + hook = residual.register_hook(save_res_grad_hook) + global_hooks.append(hook) + + hidden_states = residual + hidden_states + + # Fully Connected + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + else: + raise AttributeError + + return outputs + + +def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + return_dict: Optional[bool] = None, +): + assert cache_position is None, "cache_position is not supported" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + 
use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + attention_mask = None + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + try: + logger.warning_once( + "***** Using fast gradient checkpointing... *****" + ) + except: + pass + # initialize the global buffer + # init_flash_attn_buffers(len(self.layers)) + global offload_buffer + if offload_buffer.enable_offload: + offload_buffer.allocate( + self.config.num_hidden_layers, + shape=[batch_size, self.config.num_attention_heads, seq_length, self.config.hidden_size // self.config.num_attention_heads] + ) + else: + init_flash_attn_buffers(len(self.layers)) + if use_cache: + try: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + except: + pass + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + # apply flash-attention friendly gradient checkpointing + if self.gradient_checkpointing and self.training: + for idx in range(len(self.layers) + 1): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + def forward_first_attn_module(module): + def custom_forward(*inputs): + hidden_states, attention_mask, position_ids, _ = inputs + # None for past_key_value + return module(hidden_states, attention_mask, position_ids, past_key_value, output_attentions, compute_attn_only=True) + return custom_forward + + def forward_ffn_attn_layer(module1, module2): + def custom_forward(*inputs): + hidden_states, attention_mask, position_ids, residual = inputs + # None for past_key_value + layer_outputs = module1(hidden_states, attention_mask, position_ids, past_key_value, output_attentions, compute_ffn_only=True, residual=residual) + hidden_states = layer_outputs[0] + return module2(hidden_states, attention_mask, position_ids, past_key_value, output_attentions, compute_attn_only=True) + return custom_forward + + def forward_last_ffn_module(module): + def custom_forward(*inputs): + hidden_states, attention_mask, position_ids, residual = inputs + # None for past_key_value + return module(hidden_states, attention_mask, position_ids, past_key_value, output_attentions, compute_ffn_only=True, residual=residual) + return custom_forward + + if idx == 0: + layer_outputs = checkpoint_end_with_flash_attention( + forward_first_attn_module(self.layers[0]), + idx, + hidden_states, + attention_mask, + position_ids, + None, + ) + hidden_states, residual = layer_outputs[0], layer_outputs[-1] + elif idx == len(self.layers): + layer_outputs = checkpoint_last_module( + forward_last_ffn_module(self.layers[-1]), + hidden_states, + attention_mask, + position_ids, + residual, + ) + hidden_states = layer_outputs[0] + else: + layer_outputs = checkpoint_end_with_flash_attention( + forward_ffn_attn_layer(self.layers[idx-1], self.layers[idx]), + idx, + hidden_states, + attention_mask, + position_ids, + residual, + ) + hidden_states, residual = layer_outputs[0], layer_outputs[-1] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + else: + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( 
+ last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + +def llama_model_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **loss_kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + logits = logits.float() + + loss = None + if labels is not None and hasattr(self, 'loss_function'): + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs) + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +def apply_dist_flash_attn_monkey_patch_llama(sp_size=None, enable_offload=False, offload_percent=0.): + initialize_distributed(sp_size=sp_size) + global offload_buffer + offload_buffer = OffloadBuffer(enable_offload=enable_offload, offload_percent=offload_percent) + transformers.models.llama.modeling_llama.LlamaModel.forward = forward + transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = llama_layer_forward + transformers.models.llama.modeling_llama.LlamaForCausalLM.forward = llama_model_forward diff --git a/src/llamafactory/easy_context/dist_flash_attn/offload_buffer.py b/src/llamafactory/easy_context/dist_flash_attn/offload_buffer.py new file mode 100644 index 0000000000..ed20759ab0 --- /dev/null +++ b/src/llamafactory/easy_context/dist_flash_attn/offload_buffer.py @@ -0,0 +1,106 @@ +import cuda +import cuda.cudart +import torch + +class OffloadBuffer: + + def __init__(self, enable_offload, offload_percent): + self.enable_offload = enable_offload + self.offload_percent = offload_percent + self.allocated = False + + def allocate(self, num_layers, shape): + if self.allocated: + return + 
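+        # Layers [0, cpu_layer_num) keep their flash-attention output, hidden states and
+        # position ids in pinned CPU buffers; the remaining layers stay on the GPU. The
+        # CUDA streams created below let these host/device copies overlap with compute.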
self.layer_num = num_layers + self.cpu_layer_num = int(num_layers * self.offload_percent) + self.gpu_layer_num = num_layers - self.cpu_layer_num + self.gpu_buffer = [None for _ in range(self.gpu_layer_num)] + self.cpu_buffer = [torch.empty(shape, dtype=torch.bfloat16, pin_memory=True) for _ in range(self.cpu_layer_num)] + bs, num_heads, seq_len, emb_size = shape + shape_h = [bs, seq_len, num_heads * emb_size] + shape_a = [bs, seq_len] + self.hidden_state_gpu_buffer = [None for _ in range(self.gpu_layer_num)] + self.hidden_state_cpu_buffer = [torch.empty(shape_h, dtype=torch.bfloat16, pin_memory=True) for _ in range(self.cpu_layer_num)] + self.position_id_gpu_buffer = [None for _ in range(self.gpu_layer_num)] + self.position_id_cpu_buffer = [torch.empty(shape_a, dtype=torch.bfloat16, pin_memory=True) for _ in range(self.cpu_layer_num)] + _, self.d2h_stream = cuda.cudart.cudaStreamCreate() + self.h2d_streams = [] + for i in range(self.gpu_layer_num): + _, h2d_stream = cuda.cudart.cudaStreamCreate() + self.h2d_streams.append(h2d_stream) + self.allocated = True + + def save_flash_attn_out(self, layer_idx, out): + if layer_idx < 0: + layer_idx = self.layer_num + layer_idx + if layer_idx < self.cpu_layer_num: + _ = cuda.cudart.cudaMemcpyAsync(self.cpu_buffer[layer_idx].data_ptr(), out.data_ptr(), out.nelement() * out.element_size(), cuda.cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, self.d2h_stream) + else: + idx = layer_idx - self.cpu_layer_num + self.gpu_buffer[idx] = out + + def save_hidden_states(self, layer_idx, *hs): + if layer_idx < 0: + layer_idx = self.layer_num + layer_idx + hidden_state = hs[0] + position_id = hs[1] + if layer_idx < self.cpu_layer_num: + _ = cuda.cudart.cudaMemcpyAsync(self.hidden_state_cpu_buffer[layer_idx].data_ptr(), hidden_state.data_ptr(), hidden_state.nelement() * hidden_state.element_size(), cuda.cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, self.d2h_stream) + _ = cuda.cudart.cudaMemcpyAsync(self.position_id_cpu_buffer[layer_idx].data_ptr(), position_id.data_ptr(), position_id.nelement() * position_id.element_size(), cuda.cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, self.d2h_stream) + else: + idx = layer_idx - self.cpu_layer_num + self.hidden_state_gpu_buffer[idx] = hidden_state + self.position_id_gpu_buffer[idx] = position_id + + def get_flash_attn_out(self, layer_idx): + if layer_idx < 0: + layer_idx = self.layer_num + layer_idx + if layer_idx >= self.cpu_layer_num: + return self.gpu_buffer[layer_idx - self.cpu_layer_num] + idx = self.gpu_layer_num -1 - (self.cpu_layer_num - 1 - layer_idx) % self.gpu_layer_num + _ = cuda.cudart.cudaStreamSynchronize(self.h2d_streams[idx]) + return self.gpu_buffer[idx] + + def get_hidden_states(self, layer_idx): + if layer_idx < 0: + layer_idx = self.layer_num + layer_idx + if layer_idx >= self.cpu_layer_num: + return self.hidden_state_gpu_buffer[layer_idx - self.cpu_layer_num], self.position_id_gpu_buffer[layer_idx - self.cpu_layer_num] + idx = self.gpu_layer_num -1 - (self.cpu_layer_num - 1 - layer_idx) % self.gpu_layer_num + _ = cuda.cudart.cudaStreamSynchronize(self.h2d_streams[idx]) + return self.hidden_state_gpu_buffer[idx], self.position_id_gpu_buffer[idx] + + def free_layer_gpu_buffer(self, layer_idx): + if layer_idx < 0: + layer_idx = self.layer_num + layer_idx + if layer_idx == self.layer_num - 1: + _ = cuda.cudart.cudaStreamSynchronize(self.d2h_stream) + cpu_layer_idx = layer_idx - self.gpu_layer_num + if layer_idx >= self.cpu_layer_num: + idx = layer_idx - self.cpu_layer_num + else: + idx = self.gpu_layer_num -1 - 
(self.cpu_layer_num - 1 - layer_idx) % self.gpu_layer_num + self.gpu_buffer[idx].grad = None + if cpu_layer_idx < 0: + self.gpu_buffer[idx] = None + self.hidden_state_gpu_buffer[idx] = None + self.position_id_gpu_buffer[idx] = None + return + cb = self.cpu_buffer[cpu_layer_idx] + hcb = self.hidden_state_cpu_buffer[cpu_layer_idx] + pcb = self.position_id_cpu_buffer[cpu_layer_idx] + gb = self.gpu_buffer[idx] + hgb = self.hidden_state_gpu_buffer[idx] + pgb = self.position_id_gpu_buffer[idx] + _ = cuda.cudart.cudaMemcpyAsync(gb.data_ptr(), cb.data_ptr(), gb.nelement() * gb.element_size(), cuda.cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, self.h2d_streams[idx]) + _ = cuda.cudart.cudaMemcpyAsync(hgb.data_ptr(), hcb.data_ptr(), hgb.nelement() * hgb.element_size(), cuda.cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, self.h2d_streams[idx]) + _ = cuda.cudart.cudaMemcpyAsync(pgb.data_ptr(), pcb.data_ptr(), pgb.nelement() * pgb.element_size(), cuda.cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, self.h2d_streams[idx]) + + def __del__(self): + if self.allocated: + cuda.cudart.cudaStreamDestroy(self.d2h_stream) + for i in range(self.gpu_layer_num): + cuda.cudart.cudaStreamDestroy(self.h2d_streams[i]) + +offload_buffer = None diff --git a/src/llamafactory/easy_context/dist_flash_attn/prepare_input.py b/src/llamafactory/easy_context/dist_flash_attn/prepare_input.py new file mode 100644 index 0000000000..ed081a32e4 --- /dev/null +++ b/src/llamafactory/easy_context/dist_flash_attn/prepare_input.py @@ -0,0 +1,72 @@ + + +def extract_local(value, rank, world_size, device, dim=1): + value_local = value.chunk(world_size, dim=dim)[rank] + if device == None: + return value_local + return value_local.to(device) + + +def prepare_dist_flash_attn_inputs( + input_ids, position_ids, target_ids, rank, world_size, device +): + local_input_ids = extract_local( + input_ids, + rank, + world_size, + device, + ) + local_position_ids = extract_local( + position_ids, + rank, + world_size, + device, + ) + if target_ids is not None: + local_target_ids = extract_local( + target_ids, + rank, + world_size, + device, + ) + else: + local_target_ids = None + return { + "local_input_ids": local_input_ids, + "local_position_ids": local_position_ids, + "local_target_ids": local_target_ids, + } + +def prepare_dist_flash_attn_sft_inputs( + input_ids, attention_mask, position_ids, labels, rank, world_size, device +): + local_input_ids = extract_local( + input_ids, + rank, + world_size, + device, + ) + local_position_ids = extract_local( + position_ids, + rank, + world_size, + device, + ) + local_attention_mask = extract_local( + attention_mask, + rank, + world_size, + device + ) + local_labels = extract_local( + labels, + rank, + world_size, + device, + ) + return { + "input_ids": local_input_ids, + "attention_mask": local_attention_mask, + "position_ids": local_position_ids, + "labels": local_labels, + } diff --git a/src/llamafactory/easy_context/ulysses_attn/monkey_patch.py b/src/llamafactory/easy_context/ulysses_attn/monkey_patch.py new file mode 100644 index 0000000000..0ba10141b8 --- /dev/null +++ b/src/llamafactory/easy_context/ulysses_attn/monkey_patch.py @@ -0,0 +1,128 @@ +import transformers +from typing import List, Optional, Tuple, Union +import warnings +import torch +import torch.utils.checkpoint +from yunchang.ulysses import UlyssesAttention + +ulysses_attn = None + +def new_flash_attn_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + 
use_sliding_windows=False, +): + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + assert attention_mask is None + assert causal is True + assert use_sliding_windows is False + attn_output = ulysses_attn( + query_states, + key_states, + value_states, + dropout, + softmax_scale, + causal=causal, + ) + + return attn_output + + +def new_decoder_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, +) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + # assert isinstance( + # self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2 + # ) or isinstance( + # self.self_attn, + # transformers.models.mistral.modeling_mistral.MistralFlashAttention2, + # ), "Please toggle on the Flash Attention 2 implementation when using zigzag ring attention monkey patch." + + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + +def get_sp_process_group(sequence_parallel_size=None): + if sequence_parallel_size is None: + return None + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + print(f"sequence_parallel_size is {sequence_parallel_size}, world_size is {world_size}") + if sequence_parallel_size is None or sequence_parallel_size == -1: + sequence_parallel_size = world_size + else: + assert world_size % sequence_parallel_size == 0 + num_sequence_parallel_groups: int = world_size // sequence_parallel_size + rank = torch.distributed.get_rank() + + for i in range(num_sequence_parallel_groups): + ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size) + if rank in ranks: + group = torch.distributed.new_group(ranks) + return group + +def apply_ulysses_attn_monkey_patch_llama(sp_size=None): + sp_group = get_sp_process_group(sp_size) + global ulysses_attn + ulysses_attn = UlyssesAttention(sequence_process_group=sp_group) + transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = ( + new_flash_attn_forward + ) + transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = ( + new_decoder_forward + ) + + diff --git a/src/llamafactory/easy_context/ulysses_attn/prepare_inputs.py 
b/src/llamafactory/easy_context/ulysses_attn/prepare_inputs.py new file mode 100644 index 0000000000..de6a0b50b4 --- /dev/null +++ b/src/llamafactory/easy_context/ulysses_attn/prepare_inputs.py @@ -0,0 +1,80 @@ +import torch + + +def extract_local(value, rank, world_size, device, dim=1): + dimension_size = value.shape[dim] + sub_seq_length = dimension_size // world_size + + sub_seq_start = rank * sub_seq_length + sub_seq_end = (rank + 1) * sub_seq_length + local_value = value[:, sub_seq_start:sub_seq_end] + if device == None: + return local_value + return local_value.to(device) + + +def prepare_ulysses_attn_inputs( + input_ids, position_ids, target_ids, rank, world_size, device +): + + local_input_ids = extract_local( + input_ids, + rank, + world_size, + device, + ) + local_position_ids = extract_local( + position_ids, + rank, + world_size, + device, + ) + + if target_ids is not None: + local_target_ids = extract_local( + target_ids, + rank, + world_size, + device, + ) + else: + local_target_ids = None + return { + "local_input_ids": local_input_ids, + "local_position_ids": local_position_ids, + "local_target_ids": local_target_ids, + } + +def prepare_ulysses_attn_sft_inputs( + input_ids, attention_mask, position_ids, labels, rank, world_size, device +): + local_input_ids = extract_local( + input_ids, + rank, + world_size, + device, + ) + local_position_ids = extract_local( + position_ids, + rank, + world_size, + device, + ) + local_attention_mask = extract_local( + attention_mask, + rank, + world_size, + device + ) + local_labels = extract_local( + labels, + rank, + world_size, + device, + ) + return { + "input_ids": local_input_ids, + "attention_mask": local_attention_mask, + "position_ids": local_position_ids, + "labels": local_labels, + } \ No newline at end of file diff --git a/src/llamafactory/easy_context/unsloth_offloaded_gradient_checkpoint/monkey_patch.py b/src/llamafactory/easy_context/unsloth_offloaded_gradient_checkpoint/monkey_patch.py new file mode 100644 index 0000000000..fb509e0ef2 --- /dev/null +++ b/src/llamafactory/easy_context/unsloth_offloaded_gradient_checkpoint/monkey_patch.py @@ -0,0 +1,94 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import transformers +import inspect + + +class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function): + """ + Saves VRAM by smartly offloading to RAM. + Tiny hit to performance, since we mask the movement via non blocking calls. 
+ """ + + @staticmethod + @torch.cuda.amp.custom_fwd + def forward(ctx, forward_function, hidden_states, *args): + saved_hidden_states = hidden_states.to("cpu", non_blocking=True) + with torch.no_grad(): + output = forward_function(hidden_states, *args) + ctx.save_for_backward(saved_hidden_states) + ctx.forward_function = forward_function + ctx.args = args + + return output + + pass + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(ctx, dY): + (hidden_states,) = ctx.saved_tensors + hidden_states = hidden_states.to("cuda", non_blocking=True).detach() + hidden_states.requires_grad = True + with torch.enable_grad(): + (output,) = ctx.forward_function(hidden_states, *ctx.args) + torch.autograd.backward(output, dY) + return ( + None, + hidden_states.grad, + ) + ( + None, + ) * len(ctx.args) + + pass + + +pass + + +def new_gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): + assert gradient_checkpointing_kwargs == None + if not self.supports_gradient_checkpointing: + raise ValueError( + f"{self.__class__.__name__} does not support gradient checkpointing." + ) + + gradient_checkpointing_func = Unsloth_Offloaded_Gradient_Checkpointer.apply + # For old GC format (transformers < 4.35.0) for models that live on the Hub + # we will fall back to the overwritten `_set_gradient_checkpointing` method + _is_using_old_format = ( + "value" in inspect.signature(self._set_gradient_checkpointing).parameters + ) + + if not _is_using_old_format: + self._set_gradient_checkpointing( + enable=True, gradient_checkpointing_func=gradient_checkpointing_func + ) + else: + raise NotImplementedError() + + if getattr(self, "_hf_peft_config_loaded", False): + # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True + # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334 + # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate + # the gradients to make sure the gradient flows. 
+ self.enable_input_require_grads() + + +def apply_unsloth_offloaded_gradient_checkpoint_monkey_patch(): + transformers.modeling_utils.PreTrainedModel.gradient_checkpointing_enable = ( + new_gradient_checkpointing_enable + ) diff --git a/src/llamafactory/easy_context/zigzag_ring_attn/monkey_patch.py b/src/llamafactory/easy_context/zigzag_ring_attn/monkey_patch.py new file mode 100644 index 0000000000..250c1ea6f1 --- /dev/null +++ b/src/llamafactory/easy_context/zigzag_ring_attn/monkey_patch.py @@ -0,0 +1,229 @@ +import transformers +from typing import List, Optional, Tuple, Union +import warnings +import torch +import torch.utils.checkpoint +from ring_flash_attn.zigzag_ring_flash_attn import zigzag_ring_flash_attn_func +from functools import partialmethod, partial +import inspect + +def new_flash_attn_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + use_sliding_windows=False, + group=None +): + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + assert attention_mask is None + assert causal is True + assert use_sliding_windows is False + attn_output = zigzag_ring_flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale, + causal=causal, + group=group + ) + + return attn_output + +def new_flash_attn_forward_v2( + query_states, + key_states, + value_states, + attention_mask, + query_length, + is_causal, + dropout=0.0, + position_ids=None, + softmax_scale=None, + sliding_window=None, + use_top_left_mask=False, + softcap=None, + group=None +): + if not use_top_left_mask: + causal = is_causal + else: + causal = is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + assert attention_mask is None + assert causal is True + assert sliding_window is None + attn_output = zigzag_ring_flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale, + causal=causal, + group=group + ) + + return attn_output + +def new_decoder_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, +) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + assert isinstance( + self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2 + ) or isinstance( + self.self_attn, + transformers.models.mistral.modeling_mistral.MistralFlashAttention2, + ), "Please toggle on the Flash Attention 2 implementation when using zigzag ring attention monkey patch." + + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + +def new_decoder_forward_v2( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, +) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + assert isinstance( + self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2 + ) or isinstance( + self.self_attn, + transformers.models.mistral.modeling_mistral.MistralFlashAttention2, + ), "Please toggle on the Flash Attention 2 implementation when using zigzag ring attention monkey patch." + + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + +def get_sp_process_group(sequence_parallel_size=None): + if sequence_parallel_size is None: + return None + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + print(f"sequence_parallel_size is {sequence_parallel_size}, world_size is {world_size}") + if sequence_parallel_size is None or sequence_parallel_size == -1: + sequence_parallel_size = world_size + else: + assert world_size % sequence_parallel_size == 0 + num_sequence_parallel_groups: int = world_size // sequence_parallel_size + rank = torch.distributed.get_rank() + + for i in range(num_sequence_parallel_groups): + ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size) + if rank in ranks: + group = torch.distributed.new_group(ranks) + return group + +def apply_zigzag_ring_attn_monkey_patch_llama(sp_size=None): + sp_group = get_sp_process_group(sp_size) + if hasattr(transformers.models.llama.modeling_llama.LlamaFlashAttention2, '_flash_attention_forward'): + transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = ( + partialmethod(new_flash_attn_forward, group=sp_group) + ) + else: + transformers.models.llama.modeling_llama._flash_attention_forward = ( + partial(new_flash_attn_forward_v2, group=sp_group) + ) + if "position_embeddings" in inspect.getfullargspec(transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward).args: + transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = ( + new_decoder_forward_v2 + ) + else: + transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = ( + new_decoder_forward + ) diff --git a/src/llamafactory/easy_context/zigzag_ring_attn/prepare_inputs.py b/src/llamafactory/easy_context/zigzag_ring_attn/prepare_inputs.py new file mode 100644 index 0000000000..6d2925aa41 --- /dev/null +++ b/src/llamafactory/easy_context/zigzag_ring_attn/prepare_inputs.py @@ -0,0 +1,76 @@ +import torch + + +def extract_local(value, rank, world_size, device, dim=1): + value_chunks = value.chunk(2 * world_size, dim=dim) + local_value = torch.cat( + [value_chunks[rank], value_chunks[2 * world_size - rank - 1]], dim=dim + ) + if device == None: + return local_value + return local_value.to(device) + + +def prepare_zigzag_ring_attn_inputs( + input_ids, position_ids, target_ids, rank, world_size, device +): + local_input_ids = extract_local( + input_ids, + rank, + world_size, + device, + ) + local_position_ids = extract_local( + position_ids, + rank, + world_size, + device, + ) + if target_ids is not None: + local_target_ids = extract_local( + target_ids, + rank, + world_size, + device, + ) + else: + 
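+        # no target_ids were given, so there is nothing to shard for this rank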
local_target_ids = None + return { + "local_input_ids": local_input_ids, + "local_position_ids": local_position_ids, + "local_target_ids": local_target_ids, + } + +def prepare_zigzag_ring_attn_sft_inputs( + input_ids, attention_mask, position_ids, labels, rank, world_size, device +): + local_input_ids = extract_local( + input_ids, + rank, + world_size, + device, + ) + local_position_ids = extract_local( + position_ids, + rank, + world_size, + device, + ) + local_attention_mask = extract_local( + attention_mask, + rank, + world_size, + device + ) + local_labels = extract_local( + labels, + rank, + world_size, + device, + ) + return { + "input_ids": local_input_ids, + "attention_mask": None, + "position_ids": local_position_ids, + "labels": local_labels, + } \ No newline at end of file diff --git a/src/llamafactory/extras/logging.py b/src/llamafactory/extras/logging.py index 430b8a48bb..a281ff12d2 100644 --- a/src/llamafactory/extras/logging.py +++ b/src/llamafactory/extras/logging.py @@ -16,7 +16,7 @@ def __init__(self, output_dir: str) -> None: formatter = logging.Formatter( fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S" ) - self.setLevel(logging.INFO) + self.setLevel(logging.DEBUG) self.setFormatter(formatter) os.makedirs(output_dir, exist_ok=True) @@ -53,7 +53,7 @@ def get_logger(name: str) -> logging.Logger: handler.setFormatter(formatter) logger = logging.getLogger(name) - logger.setLevel(logging.INFO) + logger.setLevel(logging.DEBUG) logger.addHandler(handler) return logger diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index facbe792ca..5f9889c82b 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -312,6 +312,24 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA default=False, metadata={"help": "Whether or not to save the training loss curves."}, ) + parallel_mode: Literal["zigzag_ring_attn", "dist_flash_attn", "ulysses_attn", "data_parallel"] = field( + default="data_parallel", + metadata={"help": "which sequence parallel mode to use."}, + ) + sp_size: int = field( + default=-1, + metadata={ + "help": "allow using seq_parallel and data_parallel simultaneously, -1 for all gpus parallels in sequence_length axis, n for n_gpus makes a sequence_parallel group" + }, + ) + sp_enable_offload: bool = field( + default=False, + metadata={"help": "whether enable offload activation to cpu for dist_flash_attn"}, + ) + sp_offload_percent: float = field( + default=0.0, + metadata={"help": "0 for remain all activation memory in gpu, 1 for offload all activation memory in cpu"} + ) def __post_init__(self): def split_arg(arg): diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index ec5dd62c59..5e7b9abd1a 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -294,7 +294,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: str(model_args.compute_dtype), ) ) - + logger.info(f"seed is:{training_args.seed}") transformers.set_seed(training_args.seed) return model_args, data_args, training_args, finetuning_args, generating_args diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py index 697a04e77c..ad4f7dbc95 100644 --- a/src/llamafactory/model/loader.py +++ b/src/llamafactory/model/loader.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict -from transformers import 
AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer +from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer, LlamaForCausalLM from trl import AutoModelForCausalLMWithValueHead from ..extras.logging import get_logger diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py index 1d96e82f63..d357caf6cd 100644 --- a/src/llamafactory/train/pt/trainer.py +++ b/src/llamafactory/train/pt/trainer.py @@ -5,7 +5,13 @@ from ...extras.logging import get_logger from ..trainer_utils import create_custom_optimzer, create_custom_scheduler - +import torch +from torch.utils.data import DataLoader +from transformers.utils import is_datasets_available +from transformers.trainer_utils import seed_worker +import datasets +from torch.nn import CrossEntropyLoss +import os if TYPE_CHECKING: import torch @@ -49,3 +55,162 @@ def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, if self.processor is not None: output_dir = output_dir if output_dir is not None else self.args.output_dir getattr(self.processor, "image_processor").save_pretrained(output_dir) + +class CustomSeqParallelTrainer(CustomTrainer): + def compute_loss(self, model, inputs, return_outputs=False): + """ + How the loss is computed by Trainer. By default, all models return the loss in the first element. + + Subclass and override for custom behavior. + """ + from transformers.trainer import _is_peft_model, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES + if self.label_smoother is not None and "labels" in inputs: + labels = inputs.pop("labels") + else: + labels = None + outputs = model(**inputs) + # Save past state if it exists + # TODO: this needs to be fixed and made cleaner later. + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index] + + if labels is not None: + unwrapped_model = self.accelerator.unwrap_model(model) + if _is_peft_model(unwrapped_model): + model_name = unwrapped_model.base_model.model._get_name() + else: + model_name = unwrapped_model._get_name() + if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): + loss = self.label_smoother(outputs, labels, shift_labels=True) + else: + loss = self.label_smoother(outputs, labels) + else: + if isinstance(outputs, dict) and "loss" not in outputs: + raise ValueError( + "The model did not return a loss from the inputs, only the following keys: " + f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}." + ) + # We don't use .loss here since the model may return tuples instead of ModelOutput. 
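+            # Sequence-parallel ranks each see only a shard of every sequence: sum the
+            # cross-entropy over the local shard, normalize by the number of valid
+            # (label != -100) tokens gathered across the sp group, and scale by sp_size
+            # so that averaging the per-rank losses recovers the full-sequence mean loss.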
+ if self.finetuning_args.parallel_mode== "data_parallel": + loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] + + else: + sp_size = self.finetuning_args.sp_size + loss_fn = CrossEntropyLoss(reduction='sum') + labels = inputs.pop("labels") + logits = outputs["logits"] if isinstance(outputs, dict) else outputs[1] + valid_label_cnt = (labels!=-100).sum(1)[None, :] + valid_label_cnt_gather = self.accelerator.gather(valid_label_cnt) + n_gpus = valid_label_cnt_gather.shape[0] + if sp_size == -1: + sp_size = n_gpus + dp_rank = self.accelerator.process_index // sp_size + valid_label_cnt_all =valid_label_cnt_gather[dp_rank * sp_size : (dp_rank+1) * sp_size].sum(0).detach() + shift_logits = logits.contiguous() + shift_labels = labels.contiguous() + bs = len(shift_labels) + loss = torch.zeros(bs, dtype=shift_logits.dtype, device=shift_labels.device) + for b in range(bs): + normalizer=valid_label_cnt_all[b].item() + loss[b]=loss_fn(shift_logits[b], shift_labels[b])/normalizer + loss = loss.mean()*sp_size + + return (loss, outputs) if return_outputs else loss + + def get_train_dataloader(self) -> DataLoader: + """ + Returns the training [`~torch.utils.data.DataLoader`]. + + Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed + training if necessary) otherwise. + + Subclass and override this method if you want to inject some custom behavior. + """ + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + train_dataset = self.train_dataset + data_collator = self.data_collator + if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): + train_dataset = self._remove_unused_columns(train_dataset, description="training") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="training") + + dataloader_params = { + "batch_size": self._train_batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": False, + "persistent_workers": self.args.dataloader_persistent_workers, + } + + if not isinstance(train_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_train_sampler() + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["worker_init_fn"] = seed_worker + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + if hasattr(data_collator, "seq_algo") and data_collator.seq_algo != "data_parallel": + sp_size = self.finetuning_args.sp_size + if sp_size != -1: + world_size = int(os.environ['WORLD_SIZE']) + assert sp_size != 0 and world_size % sp_size == 0, f"world_size: {world_size} should be devide by seq_parallel_size: {sp_size}" + dp_size = world_size // sp_size + dataloader_params["batch_size"] = dataloader_params["batch_size"] * dp_size + return DataLoader(train_dataset, **dataloader_params) + return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + + def get_eval_dataloader(self, eval_dataset) -> DataLoader: + """ + Returns the evaluation [`~torch.utils.data.DataLoader`]. + + Subclass and override this method if you want to inject some custom behavior. + + Args: + eval_dataset (`torch.utils.data.Dataset`, *optional*): + If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted + by the `model.forward()` method are automatically removed. It must implement `__len__`. 
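+        When a sequence-parallel mode is active with an explicit `sp_size`, the batch size is
+        multiplied by the number of data-parallel groups and the raw `DataLoader` is returned
+        without `accelerator.prepare`; the collator then slices out each group's share.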
+ """ + if eval_dataset is None and self.eval_dataset is None: + raise ValueError("Trainer: evaluation requires an eval_dataset.") + + # If we have persistent workers, don't do a fork bomb especially as eval datasets + # don't change during training + if hasattr(self, "_eval_dataloader") and self.args.dataloader_persistent_workers: + return self.accelerator.prepare(self._eval_dataloader) + eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + data_collator = self.data_collator + + if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset): + eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation") + + dataloader_params = { + "batch_size": self.args.eval_batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": False, + "persistent_workers": self.args.dataloader_persistent_workers, + } + + if not isinstance(eval_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset) + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + # accelerator.free_memory() will destroy the references, so + # we need to store the non-prepared version + eval_dataloader = DataLoader(eval_dataset, **dataloader_params) + if self.args.dataloader_persistent_workers: + self._eval_dataloader = eval_dataloader + + if hasattr(data_collator, "seq_algo") and data_collator.seq_algo != "data_parallel": + sp_size = self.finetuning_args.sp_size + if sp_size != -1: + world_size = int(os.environ['WORLD_SIZE']) + assert sp_size != 0 and world_size % sp_size == 0, f"world_size: {world_size} should be devide by seq_parallel_size: {sp_size}" + dp_size = world_size // sp_size + dataloader_params["batch_size"] = dataloader_params["batch_size"] * dp_size + return eval_dataloader + return self.accelerator.prepare(eval_dataloader) \ No newline at end of file diff --git a/src/llamafactory/train/pt/workflow.py b/src/llamafactory/train/pt/workflow.py index 8a6355674d..fc33acffae 100644 --- a/src/llamafactory/train/pt/workflow.py +++ b/src/llamafactory/train/pt/workflow.py @@ -5,12 +5,15 @@ from transformers import DataCollatorForLanguageModeling -from ...data import get_dataset, split_dataset +from ...data import get_dataset, split_dataset, SeqParallelDataCollatorForLanguageModeling from ...extras.ploting import plot_loss from ...model import load_model, load_tokenizer from ..trainer_utils import create_modelcard_and_push -from .trainer import CustomTrainer +from .trainer import CustomTrainer, CustomSeqParallelTrainer +import os +import torch +from ...easy_context import apply_seq_parallel_monkey_patch if TYPE_CHECKING: from transformers import Seq2SeqTrainingArguments, TrainerCallback @@ -29,10 +32,34 @@ def run_pt( tokenizer = tokenizer_module["tokenizer"] dataset = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) - data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + apply_seq_parallel_monkey_patch(finetuning_args.parallel_mode, "llama", sp_size=finetuning_args.sp_size) + + # data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + local_rank = int(os.getenv("LOCAL_RANK")) + print(f"seq_len: 
{data_args.cutoff_len}") + + data_collator = SeqParallelDataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm=False, + seq_algo=finetuning_args.parallel_mode, + sp_size=finetuning_args.sp_size, + rank=torch.distributed.get_rank(), + world_size=torch.distributed.get_world_size(), + device=torch.device("cuda", local_rank) + ) # Initialize our Trainer - trainer = CustomTrainer( + # trainer = CustomTrainer( + # model=model, + # args=training_args, + # finetuning_args=finetuning_args, + # data_collator=data_collator, + # callbacks=callbacks, + # **tokenizer_module, + # **split_dataset(dataset, data_args, training_args), + # ) + + trainer = CustomSeqParallelTrainer( model=model, args=training_args, finetuning_args=finetuning_args, @@ -51,18 +78,18 @@ def run_pt( trainer.save_state() if trainer.is_world_process_zero() and finetuning_args.plot_loss: plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) - + # Evaluation - if training_args.do_eval: - metrics = trainer.evaluate(metric_key_prefix="eval") - try: - perplexity = math.exp(metrics["eval_loss"]) - except OverflowError: - perplexity = float("inf") + # if training_args.do_eval: + # metrics = trainer.evaluate(metric_key_prefix="eval") + # try: + # perplexity = math.exp(metrics["eval_loss"]) + # except OverflowError: + # perplexity = float("inf") - metrics["perplexity"] = perplexity - trainer.log_metrics("eval", metrics) - trainer.save_metrics("eval", metrics) + # metrics["perplexity"] = perplexity + # trainer.log_metrics("eval", metrics) + # trainer.save_metrics("eval", metrics) # Create model card create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args) diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index c063b214df..ca30bcc7a2 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -10,7 +10,11 @@ from ...extras.constants import IGNORE_INDEX from ...extras.logging import get_logger from ..trainer_utils import create_custom_optimzer, create_custom_scheduler - +from torch.utils.data import DataLoader +from transformers.utils import is_datasets_available +from transformers.trainer_utils import seed_worker +import datasets +from torch.nn import CrossEntropyLoss if TYPE_CHECKING: from transformers import ProcessorMixin @@ -130,3 +134,191 @@ def save_predictions(self, predict_results: "PredictionOutput") -> None: for label, pred in zip(decoded_labels, decoded_preds): res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False)) writer.write("\n".join(res)) + +class CustomSeqParallelTrainer(CustomSeq2SeqTrainer): + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + """ + How the loss is computed by Trainer. By default, all models return the loss in the first element. + + Subclass and override for custom behavior. 
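+        When `parallel_mode` is not "data_parallel", the loss is computed from the logits: each
+        rank sums the cross-entropy over its own sequence shard and normalizes by a global token
+        count (the valid-label count gathered across its sequence-parallel group, or
+        `num_items_in_batch` when provided).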
+ """ + from transformers.trainer import _is_peft_model, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES + if not hasattr(self, 'compute_loss_func'): + self.compute_loss_func = None + if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs: + labels = inputs.pop("labels") + else: + labels = None + if hasattr(self, 'model_accepts_loss_kwargs') and self.model_accepts_loss_kwargs: + loss_kwargs = {} + if num_items_in_batch is not None: + loss_kwargs["num_items_in_batch"] = num_items_in_batch + inputs = {**inputs, **loss_kwargs} + outputs = model(**inputs) + # Save past state if it exists + # TODO: this needs to be fixed and made cleaner later. + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index] + + if labels is not None: + unwrapped_model = self.accelerator.unwrap_model(model) + if _is_peft_model(unwrapped_model): + model_name = unwrapped_model.base_model.model._get_name() + else: + model_name = unwrapped_model._get_name() + # User-defined compute_loss function + if self.compute_loss_func is not None: + loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch) + elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): + loss = self.label_smoother(outputs, labels, shift_labels=True) + else: + loss = self.label_smoother(outputs, labels) + else: + # We don't use .loss here since the model may return tuples instead of ModelOutput. + if self.finetuning_args.parallel_mode== "data_parallel": + if isinstance(outputs, dict) and "loss" not in outputs: + raise ValueError( + "The model did not return a loss from the inputs, only the following keys: " + f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}." + ) + loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] + if not hasattr(self.args, 'average_tokens_across_devices'): + self.args.average_tokens_across_devices = None + if not hasattr(self, 'model_accepts_loss_kwargs'): + self.model_accepts_loss_kwargs= None + if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs: + loss *= self.accelerator.num_processes + else: + if num_items_in_batch is None: + sp_size = self.finetuning_args.sp_size + loss_fn = CrossEntropyLoss(reduction='sum') + labels = inputs.pop("labels") + logits = outputs["logits"] if isinstance(outputs, dict) else outputs[1] + valid_label_cnt = (labels!=-100).sum(1)[None, :] + valid_label_cnt_gather = self.accelerator.gather(valid_label_cnt) + n_gpus = valid_label_cnt_gather.shape[0] + if sp_size == -1: + sp_size = n_gpus + dp_rank = self.accelerator.process_index // sp_size + valid_label_cnt_all =valid_label_cnt_gather[dp_rank * sp_size : (dp_rank+1) * sp_size].sum(0).detach() + shift_logits = logits.contiguous() + shift_labels = labels.contiguous() + bs = len(shift_labels) + loss = torch.zeros(bs, dtype=shift_logits.dtype, device=shift_labels.device) + for b in range(bs): + normalizer=valid_label_cnt_all[b].item() + loss[b]=loss_fn(shift_logits[b], shift_labels[b])/normalizer + loss = loss.mean()*sp_size + else: + assert self.args.average_tokens_across_devices is True, "must ensure average_tokens_across_devices if parallel_mode is not data_parallel" + loss_fn = CrossEntropyLoss(reduction='sum') + labels = inputs.pop("labels") + logits = outputs["logits"] if isinstance(outputs, dict) else outputs[1] + shift_logits = logits.contiguous() + shift_labels = labels.contiguous() + shift_logits = shift_logits.view(-1, self.model.config.vocab_size) + shift_labels = 
shift_labels.view(-1) + loss = loss_fn(shift_logits, shift_labels)/num_items_in_batch + + if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs: + loss *= self.accelerator.num_processes + + return (loss, outputs) if return_outputs else loss + + def get_train_dataloader(self) -> DataLoader: + """ + Returns the training [`~torch.utils.data.DataLoader`]. + + Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed + training if necessary) otherwise. + + Subclass and override this method if you want to inject some custom behavior. + """ + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + train_dataset = self.train_dataset + data_collator = self.data_collator + if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): + train_dataset = self._remove_unused_columns(train_dataset, description="training") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="training") + + dataloader_params = { + "batch_size": self._train_batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": False, + "persistent_workers": self.args.dataloader_persistent_workers, + } + + if not isinstance(train_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_train_sampler() + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["worker_init_fn"] = seed_worker + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + if hasattr(data_collator, "seq_algo") and data_collator.seq_algo != "data_parallel": + sp_size = self.finetuning_args.sp_size + if sp_size != -1: + world_size = int(os.environ['WORLD_SIZE']) + assert sp_size != 0 and world_size % sp_size == 0, f"world_size: {world_size} should be devide by seq_parallel_size: {sp_size}" + dp_size = world_size // sp_size + dataloader_params["batch_size"] = dataloader_params["batch_size"] * dp_size + return DataLoader(train_dataset, **dataloader_params) + return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + + def get_eval_dataloader(self, eval_dataset) -> DataLoader: + """ + Returns the evaluation [`~torch.utils.data.DataLoader`]. + + Subclass and override this method if you want to inject some custom behavior. + + Args: + eval_dataset (`torch.utils.data.Dataset`, *optional*): + If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted + by the `model.forward()` method are automatically removed. It must implement `__len__`. 
+ """ + if eval_dataset is None and self.eval_dataset is None: + raise ValueError("Trainer: evaluation requires an eval_dataset.") + + # If we have persistent workers, don't do a fork bomb especially as eval datasets + # don't change during training + if hasattr(self, "_eval_dataloader") and self.args.dataloader_persistent_workers: + return self.accelerator.prepare(self._eval_dataloader) + eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + data_collator = self.data_collator + + if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset): + eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation") + + dataloader_params = { + "batch_size": self.args.eval_batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": False, + "persistent_workers": self.args.dataloader_persistent_workers, + } + + if not isinstance(eval_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset) + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + # accelerator.free_memory() will destroy the references, so + # we need to store the non-prepared version + eval_dataloader = DataLoader(eval_dataset, **dataloader_params) + if self.args.dataloader_persistent_workers: + self._eval_dataloader = eval_dataloader + + if hasattr(data_collator, "seq_algo") and data_collator.seq_algo != "data_parallel": + sp_size = self.finetuning_args.sp_size + if sp_size != -1: + world_size = int(os.environ['WORLD_SIZE']) + assert sp_size != 0 and world_size % sp_size == 0, f"world_size: {world_size} should be devide by seq_parallel_size: {sp_size}" + dp_size = world_size // sp_size + dataloader_params["batch_size"] = dataloader_params["batch_size"] * dp_size + return eval_dataloader + return self.accelerator.prepare(eval_dataloader) diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py index f09b51730b..6aeda45187 100644 --- a/src/llamafactory/train/sft/workflow.py +++ b/src/llamafactory/train/sft/workflow.py @@ -5,13 +5,18 @@ from transformers import DataCollatorForSeq2Seq from ...data import get_dataset, split_dataset +from ...data.collator import SeqParallelDataCollator from ...extras.constants import IGNORE_INDEX from ...extras.misc import get_logits_processor from ...extras.ploting import plot_loss from ...model import load_model, load_tokenizer from ..trainer_utils import create_modelcard_and_push from .metric import ComputeMetrics -from .trainer import CustomSeq2SeqTrainer +from .trainer import CustomSeq2SeqTrainer, CustomSeqParallelTrainer + +import torch +import os +from ...easy_context import apply_seq_parallel_monkey_patch if TYPE_CHECKING: @@ -32,6 +37,7 @@ def run_sft( tokenizer = tokenizer_module["tokenizer"] dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) + apply_seq_parallel_monkey_patch(finetuning_args.parallel_mode, "llama", sp_size=finetuning_args.sp_size, enable_offload=finetuning_args.sp_enable_offload, offload_percent=finetuning_args.sp_offload_percent) if training_args.predict_with_generate: tokenizer.padding_side = "left" # use left-padding in generation @@ -39,19 +45,26 @@ def run_sft( 
if getattr(model, "is_quantized", False) and not training_args.do_train: setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction - data_collator = DataCollatorForSeq2Seq( + local_rank = int(os.getenv("LOCAL_RANK")) + world_size = torch.distributed.get_world_size() + print(f"seq_len: {data_args.cutoff_len}") + data_collator = SeqParallelDataCollator( tokenizer=tokenizer, - pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None, # for shift short attention + pad_to_multiple_of=data_args.cutoff_len if tokenizer.padding_side == "right" else None, label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id, + seq_algo=finetuning_args.parallel_mode, + sp_size=finetuning_args.sp_size, + rank=torch.distributed.get_rank(), + world_size=world_size, + device=torch.device("cuda", local_rank) ) - # Override the decoding parameters of Seq2SeqTrainer training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams training_args.remove_unused_columns = False if model_args.visual_inputs else training_args.remove_unused_columns # Initialize our Trainer - trainer = CustomSeq2SeqTrainer( + trainer = CustomSeqParallelTrainer( model=model, args=training_args, finetuning_args=finetuning_args, @@ -78,22 +91,22 @@ def run_sft( if trainer.is_world_process_zero() and finetuning_args.plot_loss: plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) - # Evaluation - if training_args.do_eval: - metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs) - if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled - metrics.pop("eval_loss", None) - trainer.log_metrics("eval", metrics) - trainer.save_metrics("eval", metrics) - - # Predict - if training_args.do_predict: - predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs) - if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled - predict_results.metrics.pop("predict_loss", None) - trainer.log_metrics("predict", predict_results.metrics) - trainer.save_metrics("predict", predict_results.metrics) - trainer.save_predictions(predict_results) + # # Evaluation + # if training_args.do_eval: + # metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs) + # if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled + # metrics.pop("eval_loss", None) + # trainer.log_metrics("eval", metrics) + # trainer.save_metrics("eval", metrics) + + # # Predict + # if training_args.do_predict: + # predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs) + # if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled + # predict_results.metrics.pop("predict_loss", None) + # trainer.log_metrics("predict", predict_results.metrics) + # trainer.save_metrics("predict", predict_results.metrics) + # trainer.save_predictions(predict_results) # Create model card create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args) diff --git a/tianqing_examples/Llama3-70B-pt-dp.sh b/tianqing_examples/Llama3-70B-pt-dp.sh new file mode 100644 index 0000000000..25e607a8e7 --- /dev/null +++ b/tianqing_examples/Llama3-70B-pt-dp.sh @@ -0,0 +1,59 @@ +# You can observe that the number of steps for different 
stage is quite different. They are not magic number. They are set to those numbers simply because I esitimate the time it takes to finish the training, and +# choose the number such that it fits my daily schedule>_<. This is for you to exactly reproduce my results. You many change the steps to other numbers if you want to. +MODEL_DIR=${MODEL_DIR:-"/mnt/zj-gpfs/home/lsy/models/Meta-Llama-3-8B"} +NGPUS=${NGPUS:-8} +WORLD_SIZE=${WORLD_SIZE:-1} +NUM_PROCESSES=$((${NGPUS}*${WORLD_SIZE})) +SEQ_LEN=${SEQ_LEN:-32768} +BATCH_SIZE=${BATCH_SIZE:-1} +ACC=${ACC:-4} +SAVE_PATH=${SAVE_PATH:-"./output/70B_32K_bs_1M_rope_1M_step_1000_lr_2e-5"} +SAVE_STEPS=${SAVE_STEPS:-500} +export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' +export WANDB_DISABLED=true +export NCCL_DEBUG=WARN +echo ${RANK}/$[WORLD_SIZE] +if [ ${MASTER_ADDR} == 'localhost' ]; then + export MASTER_ADDR=`hostname -i` +fi +echo ${MASTER_ADDR}:${MASTER_PORT} + +accelerate launch \ +--config_file examples/accelerate/ds_multi_nodes.yaml \ +--use_deepspeed \ +--num_machines ${WORLD_SIZE} \ +--num_processes ${NUM_PROCESSES} \ +--main_process_ip ${MASTER_ADDR} \ +--main_process_port ${MASTER_PORT} \ +--machine_rank ${RANK} \ +--rdzv_conf "rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT,rdzv_backend=c10d" \ +src/train.py \ +--model_name_or_path ${MODEL_DIR} \ +--stage pt \ +--do_train \ +--finetuning_type full \ +--parallel_mode data_parallel \ +--deepspeed examples/deepspeed/ds_z3_offload_config.json \ +--tokenized_path /mnt/zj-gpfs/home/lsy/data/tokenized_c4_demo \ +--template llama3 \ +--cutoff_len ${SEQ_LEN} \ +--max_samples 1000 \ +--overwrite_cache \ +--preprocessing_num_workers 16 \ +--output_dir ${SAVE_PATH} \ +--logging_steps 1 \ +--save_steps ${SAVE_STEPS} \ +--plot_loss \ +--overwrite_output_dir \ +--per_device_train_batch_size ${BATCH_SIZE} \ +--gradient_accumulation_steps ${ACC} \ +--learning_rate 2e-5 \ +--num_train_epochs 3.0 \ +--lr_scheduler_type cosine \ +--warmup_ratio 0.1 \ +--bf16 \ +--ddp_timeout 180000000 \ +--val_size 0.1 \ +--eval_strategy steps \ +--eval_steps 9999 \ +--dataloader_drop_last diff --git a/tianqing_examples/Llama3-70B-pt-sp-lora.sh b/tianqing_examples/Llama3-70B-pt-sp-lora.sh new file mode 100644 index 0000000000..24f0cdf169 --- /dev/null +++ b/tianqing_examples/Llama3-70B-pt-sp-lora.sh @@ -0,0 +1,59 @@ +# You can observe that the number of steps for different stage is quite different. They are not magic number. They are set to those numbers simply because I esitimate the time it takes to finish the training, and +# choose the number such that it fits my daily schedule>_<. This is for you to exactly reproduce my results. You many change the steps to other numbers if you want to. 
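+# Continued pretraining (--stage pt) with LoRA adapters and dist_flash_attn sequence parallelism; point DATA_PATH at your tokenized corpus.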
+MODEL_DIR=${MODEL_DIR:-"/mnt/zj-gpfs/home/lsy/models/Meta-Llama-3-8B"} +NGPUS=${NGPUS:-8} +WORLD_SIZE=${WORLD_SIZE:-1} +NUM_PROCESSES=$((${NGPUS}*${WORLD_SIZE})) +SEQ_LEN=${SEQ_LEN:-32768} +BATCH_SIZE=${BATCH_SIZE:-1} +ACC=${ACC:-4} +SAVE_PATH=${SAVE_PATH:-"./output/70B_32K_bs_1M_rope_1M_step_1000_lr_2e-5"} +SAVE_STEPS=${SAVE_STEPS:-500} +export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' +export WANDB_DISABLED=true +export NCCL_DEBUG=WARN +echo ${RANK}/$[WORLD_SIZE] +if [ ${MASTER_ADDR} == 'localhost' ]; then + export MASTER_ADDR=`hostname -i` +fi +echo ${MASTER_ADDR}:${MASTER_PORT} + +accelerate launch \ +--config_file examples/accelerate/ds_multi_nodes.yaml \ +--use_deepspeed \ +--num_machines ${WORLD_SIZE} \ +--num_processes ${NUM_PROCESSES} \ +--main_process_ip ${MASTER_ADDR} \ +--main_process_port ${MASTER_PORT} \ +--machine_rank ${RANK} \ +--rdzv_conf "rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT,rdzv_backend=c10d" \ +src/train.py \ +--model_name_or_path ${MODEL_DIR} \ +--stage pt \ +--do_train \ +--finetuning_type lora \ +--parallel_mode dist_flash_attn \ +--deepspeed examples/deepspeed/ds_z3_offload_config.json \ +--tokenized_path ${DATA_PATH:-"/mnt/zj-gpfs/home/lsy/data/per_source_upsample_32769_common_5b"} \ +--template llama3 \ +--cutoff_len ${SEQ_LEN} \ +--max_samples 1000000 \ +--overwrite_cache \ +--preprocessing_num_workers 16 \ +--output_dir ${SAVE_PATH} \ +--logging_steps 1 \ +--save_steps ${SAVE_STEPS} \ +--plot_loss \ +--overwrite_output_dir \ +--per_device_train_batch_size ${BATCH_SIZE} \ +--gradient_accumulation_steps ${ACC} \ +--learning_rate 2e-5 \ +--num_train_epochs 3.0 \ +--lr_scheduler_type cosine \ +--warmup_ratio 0.1 \ +--bf16 \ +--ddp_timeout 180000000 \ +--val_size 0.1 \ +--eval_strategy steps \ +--eval_steps 99999999 \ +--dataloader_drop_last diff --git a/tianqing_examples/Llama3-70B-pt-sp.sh b/tianqing_examples/Llama3-70B-pt-sp.sh new file mode 100644 index 0000000000..955079a889 --- /dev/null +++ b/tianqing_examples/Llama3-70B-pt-sp.sh @@ -0,0 +1,59 @@ +# You can observe that the number of steps for different stage is quite different. They are not magic number. They are set to those numbers simply because I esitimate the time it takes to finish the training, and +# choose the number such that it fits my daily schedule>_<. This is for you to exactly reproduce my results. You many change the steps to other numbers if you want to. 
+MODEL_DIR=${MODEL_DIR:-"/mnt/zj-gpfs/home/lsy/models/Meta-Llama-3-8B"} +NGPUS=${NGPUS:-8} +WORLD_SIZE=${WORLD_SIZE:-1} +NUM_PROCESSES=$((${NGPUS}*${WORLD_SIZE})) +SEQ_LEN=${SEQ_LEN:-32768} +BATCH_SIZE=${BATCH_SIZE:-1} +ACC=${ACC:-4} +SAVE_PATH=${SAVE_PATH:-"./output/70B_32K_bs_1M_rope_1M_step_1000_lr_2e-5"} +SAVE_STEPS=${SAVE_STEPS:-500} +export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' +export WANDB_DISABLED=true +export NCCL_DEBUG=WARN +echo ${RANK}/$[WORLD_SIZE] +if [ ${MASTER_ADDR} == 'localhost' ]; then + export MASTER_ADDR=`hostname -i` +fi +echo ${MASTER_ADDR}:${MASTER_PORT} + +accelerate launch \ +--config_file examples/accelerate/ds_multi_nodes.yaml \ +--use_deepspeed \ +--num_machines ${WORLD_SIZE} \ +--num_processes ${NUM_PROCESSES} \ +--main_process_ip ${MASTER_ADDR} \ +--main_process_port ${MASTER_PORT} \ +--machine_rank ${RANK} \ +--rdzv_conf "rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT,rdzv_backend=c10d" \ +src/train.py \ +--model_name_or_path ${MODEL_DIR} \ +--stage pt \ +--do_train \ +--finetuning_type full \ +--parallel_mode dist_flash_attn \ +--deepspeed examples/deepspeed/ds_z3_offload_config.json \ +--tokenized_path /mnt/zj-gpfs/home/lsy/data/tokenized_c4_demo \ +--template llama3 \ +--cutoff_len ${SEQ_LEN} \ +--max_samples 1000 \ +--overwrite_cache \ +--preprocessing_num_workers 16 \ +--output_dir ${SAVE_PATH} \ +--logging_steps 1 \ +--save_steps ${SAVE_STEPS} \ +--plot_loss \ +--overwrite_output_dir \ +--per_device_train_batch_size ${BATCH_SIZE} \ +--gradient_accumulation_steps ${ACC} \ +--learning_rate 2e-5 \ +--num_train_epochs 3.0 \ +--lr_scheduler_type cosine \ +--warmup_ratio 0.1 \ +--bf16 \ +--ddp_timeout 180000000 \ +--val_size 0.1 \ +--eval_strategy steps \ +--eval_steps 9999 \ +--dataloader_drop_last diff --git a/tianqing_examples/Llama3-70B-sp-offload.sh b/tianqing_examples/Llama3-70B-sp-offload.sh new file mode 100644 index 0000000000..93475ced61 --- /dev/null +++ b/tianqing_examples/Llama3-70B-sp-offload.sh @@ -0,0 +1,60 @@ +# You can observe that the number of steps for different stage is quite different. They are not magic number. They are set to those numbers simply because I esitimate the time it takes to finish the training, and +# choose the number such that it fits my daily schedule>_<. This is for you to exactly reproduce my results. You many change the steps to other numbers if you want to. 
+MODEL_DIR=${MODEL_DIR:-"/root/model/Meta-Llama-3-70B"} +NGPUS=${NGPUS:-8} +WORLD_SIZE=${WORLD_SIZE:-1} +NUM_PROCESSES=$[${NGPUS}*$[WORLD_SIZE]] +SEQ_LEN=${SEQ_LEN:-32768} +SP_SIZE=${SP_SIZE:-1} +SP_OFFLOAD_PERCENT=${SP_OFFLOAD_PERCENT:-0.8} +BATCH_SIZE=${BATCH_SIZE:-1} +export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' +export WANDB_DISABLED=true +export NCCL_DEBUG=WARN +echo ${RANK}/$[WORLD_SIZE] +if [ ${MASTER_ADDR} == 'localhost' ]; then + export MASTER_ADDR=`hostname -i` +fi +echo ${MASTER_ADDR}:${MASTER_PORT} + +accelerate launch \ +--config_file examples/accelerate/ds_multi_nodes.yaml \ +--use_deepspeed \ +--num_machines ${WORLD_SIZE} \ +--num_processes ${NUM_PROCESSES} \ +--main_process_ip ${MASTER_ADDR} \ +--main_process_port ${MASTER_PORT} \ +--machine_rank ${RANK} \ +--rdzv_conf "rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT,rdzv_backend=c10d" \ +src/train.py \ +--model_name_or_path ${MODEL_DIR} \ +--stage sft \ +--do_train \ +--finetuning_type full \ +--parallel_mode dist_flash_attn \ +--sp_size ${SP_SIZE} \ +--sp_enable_offload \ +--sp_offload_percent ${SP_OFFLOAD_PERCENT} \ +--deepspeed examples/deepspeed/ds_z3_offload_config.json \ +--dataset long_sft_32k \ +--template llama3 \ +--cutoff_len ${SEQ_LEN} \ +--max_steps 10 \ +--overwrite_cache \ +--preprocessing_num_workers 16 \ +--output_dir ./output/70B_32K_bs_1M_rope_1M_step_1000_lr_2e-5 \ +--logging_steps 1 \ +--save_steps 500 \ +--plot_loss \ +--overwrite_output_dir \ +--per_device_train_batch_size ${BATCH_SIZE} \ +--gradient_accumulation_steps 4 \ +--learning_rate 2e-5 \ +--num_train_epochs 1.0 \ +--lr_scheduler_type cosine \ +--warmup_ratio 0.1 \ +--bf16 \ +--ddp_timeout 180000000 \ +--val_size 0.1 \ +--eval_strategy steps \ +--eval_steps 1000 diff --git a/tianqing_examples/Llama3-70B.sh b/tianqing_examples/Llama3-70B.sh new file mode 100644 index 0000000000..a3d9e409c1 --- /dev/null +++ b/tianqing_examples/Llama3-70B.sh @@ -0,0 +1,60 @@ +# You can observe that the number of steps for different stage is quite different. They are not magic number. They are set to those numbers simply because I esitimate the time it takes to finish the training, and +# choose the number such that it fits my daily schedule>_<. This is for you to exactly reproduce my results. You many change the steps to other numbers if you want to. 
+MODEL_DIR=${MODEL_DIR:-"/mnt/zj-gpfs/home/qianhao/models/Meta-Llama-3-70B/"} +NGPUS=${NGPUS:-8} +WORLD_SIZE=${WORLD_SIZE:-1} +NUM_PROCESSES=$[${NGPUS}*$[WORLD_SIZE]] +SEQ_LEN=${SEQ_LEN:-32768} +SP_SIZE=${SP_SIZE:-1} +BATCH_SIZE=${BATCH_SIZE:-1} +PARALLEL_MODE=${PARALLEL_MODE:-"dist_flash_attn"} +DATASET=${DATASET:-"long_sft_32k"} +ACC=${ACC:-4} +export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' +export WANDB_DISABLED=true +export NCCL_DEBUG=WARN +echo ${RANK}/$[WORLD_SIZE] +if [ ${MASTER_ADDR} == 'localhost' ]; then + export MASTER_ADDR=`hostname -i` +fi +echo ${MASTER_ADDR}:${MASTER_PORT} + +accelerate launch \ +--config_file examples/accelerate/ds_multi_nodes.yaml \ +--use_deepspeed \ +--num_machines ${WORLD_SIZE} \ +--num_processes ${NUM_PROCESSES} \ +--main_process_ip ${MASTER_ADDR} \ +--main_process_port ${MASTER_PORT} \ +--machine_rank ${RANK} \ +--rdzv_conf "rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT,rdzv_backend=c10d" \ +src/train.py \ +--model_name_or_path ${MODEL_DIR} \ +--stage sft \ +--do_train \ +--finetuning_type full \ +--parallel_mode ${PARALLEL_MODE} \ +--sp_size ${SP_SIZE} \ +--deepspeed examples/deepspeed/ds_z3_offload_config.json \ +--dataset ${DATASET} \ +--template llama3 \ +--cutoff_len ${SEQ_LEN} \ +--max_steps 10 \ +--overwrite_cache \ +--preprocessing_num_workers 16 \ +--output_dir ./output/70B_32K_bs_1M_rope_1M_step_1000_lr_2e-5 \ +--logging_steps 1 \ +--save_steps 500 \ +--plot_loss \ +--overwrite_output_dir \ +--per_device_train_batch_size ${BATCH_SIZE} \ +--gradient_accumulation_steps ${ACC} \ +--learning_rate 2e-5 \ +--num_train_epochs 1.0 \ +--lr_scheduler_type cosine \ +--warmup_ratio 0.1 \ +--bf16 \ +--ddp_timeout 180000000 \ +--val_size 0.1 \ +--eval_strategy steps \ +--eval_steps 1000 diff --git a/tianqing_examples/Llama3-8B.sh b/tianqing_examples/Llama3-8B.sh new file mode 100644 index 0000000000..b93f5823f8 --- /dev/null +++ b/tianqing_examples/Llama3-8B.sh @@ -0,0 +1,59 @@ +# You can observe that the number of steps for different stage is quite different. They are not magic number. They are set to those numbers simply because I esitimate the time it takes to finish the training, and +# choose the number such that it fits my daily schedule>_<. This is for you to exactly reproduce my results. You many change the steps to other numbers if you want to. 
+MODEL_DIR=${MODEL_DIR:-"/mnt/zj-gpfs/home/qianhao/models/Meta-Llama-3-8B"} +NGPUS=${NGPUS:-8} +WORLD_SIZE=${WORLD_SIZE:-1} +NUM_PROCESSES=$[${NGPUS}*$[WORLD_SIZE]] +SEQ_LEN=${SEQ_LEN:-1024} +SP_SIZE=${SP_SIZE:-1} +BATCH_SIZE=${BATCH_SIZE:-1} +export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' +export WANDB_DISABLED=true +echo ${RANK}/$[WORLD_SIZE] +if [ ${MASTER_ADDR} == 'localhost' ]; then + export MASTER_ADDR=`hostname -i` +fi +echo ${MASTER_ADDR}:${MASTER_PORT} + +accelerate launch \ +--config_file examples/accelerate/ds_multi_nodes.yaml \ +--use_deepspeed \ +--num_machines ${WORLD_SIZE} \ +--num_processes ${NUM_PROCESSES} \ +--main_process_ip ${MASTER_ADDR} \ +--main_process_port ${MASTER_PORT} \ +--machine_rank ${RANK} \ +--rdzv_conf "rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT,rdzv_backend=c10d" \ +src/train.py \ +--model_name_or_path ${MODEL_DIR} \ +--stage sft \ +--do_train \ +--finetuning_type full \ +--lora_target all \ +--parallel_mode dist_flash_attn \ +--sp_size ${SP_SIZE} \ +--deepspeed examples/deepspeed/ds_z3_offload_config.json \ +--dataset alpaca_en \ +--template llama3 \ +--cutoff_len ${SEQ_LEN} \ +--max_samples 1200 \ +--overwrite_cache \ +--preprocessing_num_workers 16 \ +--output_dir ./output/8B_1K_bs_1_step_1000_lr_2e-5 \ +--logging_steps 1 \ +--save_steps 500 \ +--plot_loss \ +--overwrite_output_dir \ +--per_device_train_batch_size ${BATCH_SIZE} \ +--per_device_eval_batch_size ${BATCH_SIZE} \ +--gradient_accumulation_steps 4 \ +--learning_rate 2e-5 \ +--num_train_epochs 3.0 \ +--lr_scheduler_type cosine \ +--warmup_ratio 0.1 \ +--bf16 \ +--ddp_timeout 180000000 \ +--val_size 0.1 \ +--eval_strategy steps \ +--dataloader_drop_last \ +--eval_steps 1001 diff --git a/tianqing_examples/llama3_full_sft_ds3.yaml b/tianqing_examples/llama3_full_sft_ds3.yaml new file mode 100644 index 0000000000..c37060276e --- /dev/null +++ b/tianqing_examples/llama3_full_sft_ds3.yaml @@ -0,0 +1,40 @@ +### model +model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct + +### method +stage: sft +do_train: true +finetuning_type: full +parallel_mode: dist_flash_attn +deepspeed: examples/deepspeed/ds_z3_offload_config.json + +### dataset +dataset: identity,alpaca_en_demo +template: llama3 +cutoff_len: 1024 +max_samples: 1000 +overwrite_cache: true +preprocessing_num_workers: 16 + +### output +output_dir: saves/llama3-8b/full/sft +logging_steps: 10 +save_steps: 500 +plot_loss: true +overwrite_output_dir: true + +### train +per_device_train_batch_size: 1 +gradient_accumulation_steps: 2 +learning_rate: 1.0e-4 +num_train_epochs: 3.0 +lr_scheduler_type: cosine +warmup_ratio: 0.1 +fp16: true +ddp_timeout: 180000000 + +### eval +val_size: 0.1 +per_device_eval_batch_size: 1 +eval_strategy: steps +eval_steps: 500