From 0fbc4576796a69eaca10166b50f556e43a741eae Mon Sep 17 00:00:00 2001 From: Bhargav Date: Tue, 3 Dec 2024 22:59:01 +0530 Subject: [PATCH] Adding support for Context Parallelism using Deepseed's DistributedAttention (#1501) Co-authored-by: regisss <15324346+regisss@users.noreply.github.com> --- optimum/habana/accelerate/accelerator.py | 4 +- optimum/habana/accelerate/data_loader.py | 5 + optimum/habana/accelerate/state.py | 6 +- optimum/habana/distributed/contextparallel.py | 30 ++ optimum/habana/parallel_state.py | 467 ++++++++++++++++++ .../models/llama/modeling_llama.py | 62 ++- optimum/habana/transformers/trainer.py | 17 +- optimum/habana/transformers/training_args.py | 6 + tests/baselines/llama_7b.json | 29 ++ tests/test_examples.py | 10 + 10 files changed, 625 insertions(+), 11 deletions(-) create mode 100644 optimum/habana/distributed/contextparallel.py create mode 100644 optimum/habana/parallel_state.py diff --git a/optimum/habana/accelerate/accelerator.py b/optimum/habana/accelerate/accelerator.py index e5c0ea9ea9..e5fb539a5b 100644 --- a/optimum/habana/accelerate/accelerator.py +++ b/optimum/habana/accelerate/accelerator.py @@ -59,6 +59,8 @@ from accelerate.utils.other import is_compiled_module from torch.optim.lr_scheduler import LRScheduler +from .. import parallel_state + if is_deepspeed_available(): from accelerate.utils import ( @@ -123,6 +125,7 @@ def __init__( force_autocast: bool = False, ): self.trackers = [] + self.mpu = parallel_state if project_config is not None: self.project_configuration = project_config else: @@ -775,7 +778,6 @@ def _prepare_deepspeed(self, *args): # This env variable is initialized here to make sure it is set to "true" # It should be done by the launcher but it does not work for multi-node runs os.environ["DEEPSPEED_USE_HPU"] = "true" - engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs) # torch.compile should be called if dynamo plugin backend is set and only if the model isn't already compiled. if self.state.dynamo_plugin.backend == GaudiDynamoBackend.HPU_BACKEND and not is_compiled_module( diff --git a/optimum/habana/accelerate/data_loader.py b/optimum/habana/accelerate/data_loader.py index ae00b8976d..afe8fc1cd8 100644 --- a/optimum/habana/accelerate/data_loader.py +++ b/optimum/habana/accelerate/data_loader.py @@ -22,6 +22,7 @@ ) from torch.utils.data import BatchSampler, DataLoader, IterableDataset +from .. import parallel_state from .state import GaudiAcceleratorState from .utils.operations import ( broadcast, @@ -306,6 +307,10 @@ def gaudi_prepare_data_loader( num_processes = state.num_processes if process_index is None: process_index = state.process_index + if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1: + num_processes = int(num_processes / parallel_state.get_sequence_parallel_world_size()) + if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1: + process_index = int(process_index / parallel_state.get_sequence_parallel_world_size()) # Sanity check if split_batches: diff --git a/optimum/habana/accelerate/state.py b/optimum/habana/accelerate/state.py index 8d6c39af38..b9c4e794f7 100644 --- a/optimum/habana/accelerate/state.py +++ b/optimum/habana/accelerate/state.py @@ -21,6 +21,7 @@ from optimum.utils import logging +from .. 
import parallel_state from .utils import GaudiDistributedType @@ -50,7 +51,7 @@ def __init__(self, cpu: bool = False, **kwargs): if int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu: world_size, rank, local_rank = initialize_distributed_hpu() self.backend = kwargs.pop("backend", "hccl") - + context_parallel_size = kwargs.pop("context_parallel_size", 1) if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true": if not is_deepspeed_available(): raise ImportError( @@ -85,6 +86,9 @@ def __init__(self, cpu: bool = False, **kwargs): if self.device is None: # TODO: replace by `torch.device("hpu", self.local_process_index)` when hpu:x is supported self.device = torch.device("hpu") + if not is_deepspeed_available(): + context_parallel_size = 1 + parallel_state.initialize_model_parallel(sequence_parallel_size=context_parallel_size, use_fp8=False) else: self.distributed_type = ( GaudiDistributedType.NO diff --git a/optimum/habana/distributed/contextparallel.py b/optimum/habana/distributed/contextparallel.py new file mode 100644 index 0000000000..0b48465542 --- /dev/null +++ b/optimum/habana/distributed/contextparallel.py @@ -0,0 +1,30 @@ +import torch + +from ..parallel_state import ( + get_sequence_parallel_group, + get_sequence_parallel_rank, + get_sequence_parallel_world_size, +) + + +# Gather losses across context parallel group +class _ContextParallelLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, loss): + ctx.seqlen = loss.size(0) * get_sequence_parallel_world_size() + + loss_all = torch.empty(ctx.seqlen, dtype=loss.dtype, device=loss.device) + torch.distributed.all_gather_into_tensor(loss_all, loss, group=get_sequence_parallel_group()) + return loss_all + + @staticmethod + def backward(ctx, grad_output): + step_seqlen = ctx.seqlen // get_sequence_parallel_world_size() + sp_rank = get_sequence_parallel_rank() + grad_output_part = grad_output[step_seqlen * sp_rank : step_seqlen * (sp_rank + 1)] + + return grad_output_part, None + + +def _get_loss_from_context_parallel(vocab_parallel_loss): + return _ContextParallelLoss.apply(vocab_parallel_loss) diff --git a/optimum/habana/parallel_state.py b/optimum/habana/parallel_state.py new file mode 100644 index 0000000000..c370d88229 --- /dev/null +++ b/optimum/habana/parallel_state.py @@ -0,0 +1,467 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""Model and data parallel groups.""" + +from typing import Optional + +import torch + + +# Intra-layer model parallel group that the current rank belongs to. +_TENSOR_MODEL_PARALLEL_GROUP = None +# Inter-layer model parallel group that the current rank belongs to. +_PIPELINE_MODEL_PARALLEL_GROUP = None +# Model parallel group (both intra- and pipeline) that the current rank belongs to. +_MODEL_PARALLEL_GROUP = None +# Embedding group. +_EMBEDDING_GROUP = None +# Position embedding group. +_POSITION_EMBEDDING_GROUP = None +# Data parallel group that the current rank belongs to. +_DATA_PARALLEL_GROUP = None +_DATA_PARALLEL_GROUP_GLOO = None +# FP8 amax reduction group. +_AMAX_REDUCTION_GROUP = None + +_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None +_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None +_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None + +_TRAINING_MODE = None + +# These values enable us to change the mpu sizes on the fly. 
+_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None +_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None +_MPU_TENSOR_MODEL_PARALLEL_RANK = None +_MPU_PIPELINE_MODEL_PARALLEL_RANK = None + +# A list of ranks that have a copy of the embedding. +_EMBEDDING_GLOBAL_RANKS = None + +# A list of ranks that have a copy of the position embedding. +_POSITION_EMBEDDING_GLOBAL_RANKS = None + +# A list of global ranks for each pipeline group to ease calculation of the source +# rank when broadcasting from the first or last pipeline stage. +_PIPELINE_GLOBAL_RANKS = None + +# For DeepSpeed's sequence parallel +_SEQUENCE_PARALLEL_GROUP = None +_SEQUENCE_PARALLEL_WORLD_SIZE = None +_SEQUENCE_PARALLEL_RANK = None + +# This group includes processes for both data and sequence parallelisms. +# We use this group to reduce gradients and shard parameters and optimizer stages for ZeRO. +_SEQUENCE_DATA_PARALLEL_GROUP = None +_SEQUENCE_DATA_PARALLEL_WORLD_SIZE = None +_SEQUENCE_DATA_PARALLEL_RANK = None + +# A list of global ranks for each data parallel group to ease calculation of the source +# rank when broadcasting weights from src to all other data parallel ranks +_DATA_PARALLEL_GLOBAL_RANKS = None + +# Memory buffers to avoid dynamic memory allocation +_GLOBAL_MEMORY_BUFFER = None + + +def initialize_model_parallel( + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + sequence_parallel_size: int = 1, + virtual_pipeline_model_parallel_size: Optional[int] = None, + pipeline_model_parallel_split_rank: Optional[int] = None, + use_fp8: bool = False, + use_distributed_optimizer: bool = False, +) -> None: + """Initialize model data parallel groups. + + Arguments: + tensor_model_parallel_size (int, default = 1): + The number of GPUs to split individual tensors across. + + pipeline_model_parallel_size (int, default = 1): + The number of tensor parallel GPU groups to split the + Transformer layers across. For example, if + tensor_model_parallel_size is 4 and + pipeline_model_parallel_size is 2, the model will be split + into 2 groups of 4 GPUs. + + virtual_pipeline_model_parallel_size (int, optional): + The number of stages that each pipeline group will have, + interleaving as necessary. If None, no interleaving is + performed. For example, if tensor_model_parallel_size is 1, + pipeline_model_parallel_size is 4, + virtual_pipeline_model_parallel_size is 2, and there are + 16 transformer layers in the model, the model will be + split into 8 stages with two layers each and each GPU + would get 2 stages as such (layer number starting with 1): + + GPU 0: [1, 2] [9, 10] + GPU 1: [3, 4] [11, 12] + GPU 2: [5, 6] [13, 14] + GPU 3: [7, 8] [15, 16] + + pipeline_model_parallel_split_rank (int, optional): + For models with both an encoder and decoder, the rank in + pipeline to switch between encoder and decoder (i.e. the + first rank of the decoder). This allows the user to set + the pipeline parallel size of the encoder and decoder + independently. For example, if + pipeline_model_parallel_size is 8 and + pipeline_model_parallel_split_rank is 3, then ranks 0-2 + will be the encoder and ranks 3-7 will be the decoder. + + use_fp8 (bool, default = False): + Construct GPU groups needed for FP8 training, namely for + amax reduction across the product of the data-parallel and + tensor-parallel groups. + + Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we + use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize + the model pipeline. 
The present function will + create 8 tensor model-parallel groups, 4 pipeline model-parallel groups + and 8 data-parallel groups as: + 8 data_parallel groups: + [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15] + 8 tensor model-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15] + 4 pipeline model-parallel groups: + [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + + """ + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + + if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0: + raise RuntimeError( + f"world_size ({world_size}) is not divisible by tensor_model_parallel_size " + f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})" + ) + + enable_ds_sequence_parallel = sequence_parallel_size > 1 + if enable_ds_sequence_parallel: + assert ( + tensor_model_parallel_size == 1 and pipeline_model_parallel_size == 1 + ), "DeepSpeed's sequence parallel does not work with tensor parallel or pipeline parallel" + + if world_size % sequence_parallel_size != 0: + raise RuntimeError( + f"world_size ({world_size}) is not divisible by sequence_parallel_size {sequence_parallel_size})" + ) + + data_parallel_size: int = world_size // ( + tensor_model_parallel_size * pipeline_model_parallel_size * sequence_parallel_size + ) + sequence_data_parallel_size: int = sequence_parallel_size * data_parallel_size + + num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size + num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size + # num_data_parallel_groups: int = world_size // data_parallel_size + num_sequence_parallel_groups: int = world_size // sequence_parallel_size + num_sequence_data_parallel_groups: int = world_size // sequence_parallel_size // data_parallel_size + + if virtual_pipeline_model_parallel_size is not None: + if not pipeline_model_parallel_size > 2: + raise RuntimeError("pipeline-model-parallel size should be greater than 2 with " "interleaved schedule") + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0 + _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size + + if pipeline_model_parallel_split_rank is not None: + global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK + _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank + + rank = torch.distributed.get_rank() + + # Build the data-parallel groups. 
+ global _DATA_PARALLEL_GROUP + global _DATA_PARALLEL_GROUP_GLOO + global _DATA_PARALLEL_GLOBAL_RANKS + assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized" + all_data_parallel_group_ranks = [] + for i in range(pipeline_model_parallel_size): + start_rank = i * num_pipeline_model_parallel_groups + end_rank = (i + 1) * num_pipeline_model_parallel_groups + + if sequence_parallel_size > 1: + tp_or_sp_size = sequence_parallel_size + else: + tp_or_sp_size = tensor_model_parallel_size + + for j in range(tp_or_sp_size): + ranks = range(start_rank + j, end_rank, tp_or_sp_size) + all_data_parallel_group_ranks.append(list(ranks)) + group = torch.distributed.new_group(ranks) + if use_distributed_optimizer: + group_gloo = torch.distributed.new_group(ranks, backend="gloo") + else: + group_gloo = None + if rank in ranks: + _DATA_PARALLEL_GROUP = group + _DATA_PARALLEL_GROUP_GLOO = group_gloo + _DATA_PARALLEL_GLOBAL_RANKS = ranks + + # Build the sequence parallel groups. + global _SEQUENCE_PARALLEL_GROUP + assert _SEQUENCE_PARALLEL_GROUP is None, "sequence parallel group is already initialized" + for i in range(num_sequence_parallel_groups): + ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size) + group = torch.distributed.new_group(ranks) + if rank in ranks: + _SEQUENCE_PARALLEL_GROUP = group + _SEQUENCE_PARALLEL_WORLD_SIZE = sequence_parallel_size + + global _TRAINING_MODE + _TRAINING_MODE = True + + # Build the sequence data parallel groups. + global _SEQUENCE_DATA_PARALLEL_GROUP + assert _SEQUENCE_DATA_PARALLEL_GROUP is None, "sequence data parallel group is already initialized" + all_data_sequence_parallel_group_ranks = [] + if enable_ds_sequence_parallel: + for i in range(num_sequence_data_parallel_groups): + ranks = range(i * sequence_data_parallel_size, (i + 1) * sequence_data_parallel_size) + group = torch.distributed.new_group(ranks) + all_data_sequence_parallel_group_ranks.append(list(ranks)) + if rank in ranks: + _SEQUENCE_DATA_PARALLEL_GROUP = group + else: + _SEQUENCE_DATA_PARALLEL_GROUP = _DATA_PARALLEL_GROUP + + # Build the model-parallel groups. + global _MODEL_PARALLEL_GROUP + assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized" + num_model_parallel_groups = sequence_data_parallel_size if enable_ds_sequence_parallel else data_parallel_size + model_parallel_group_ranks = ( + all_data_sequence_parallel_group_ranks if enable_ds_sequence_parallel else all_data_parallel_group_ranks + ) + for i in range(num_model_parallel_groups): + ranks = [parallel_group_ranks[i] for parallel_group_ranks in model_parallel_group_ranks] + group = torch.distributed.new_group(ranks) + if rank in ranks: + _MODEL_PARALLEL_GROUP = group + + # Build the tensor model-parallel groups. + global _TENSOR_MODEL_PARALLEL_GROUP + assert _TENSOR_MODEL_PARALLEL_GROUP is None, "tensor model parallel group is already initialized" + for i in range(num_tensor_model_parallel_groups): + ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) + group = torch.distributed.new_group(ranks) + if rank in ranks: + _TENSOR_MODEL_PARALLEL_GROUP = group + + # Build the pipeline model-parallel groups and embedding groups + # (first and last rank in each pipeline model-parallel group). 
+ global _PIPELINE_MODEL_PARALLEL_GROUP + global _PIPELINE_GLOBAL_RANKS + assert _PIPELINE_MODEL_PARALLEL_GROUP is None, "pipeline model parallel group is already initialized" + global _EMBEDDING_GROUP + global _EMBEDDING_GLOBAL_RANKS + assert _EMBEDDING_GROUP is None, "embedding group is already initialized" + global _POSITION_EMBEDDING_GROUP + global _POSITION_EMBEDDING_GLOBAL_RANKS + assert _POSITION_EMBEDDING_GROUP is None, "position embedding group is already initialized" + for i in range(num_pipeline_model_parallel_groups): + ranks = range(i, world_size, num_pipeline_model_parallel_groups) + group = torch.distributed.new_group(ranks) + if rank in ranks: + _PIPELINE_MODEL_PARALLEL_GROUP = group + _PIPELINE_GLOBAL_RANKS = ranks + # Setup embedding group (to exchange gradients between + # first and last stages). + if len(ranks) > 1: + embedding_ranks = [ranks[0], ranks[-1]] + position_embedding_ranks = [ranks[0]] + if pipeline_model_parallel_split_rank is not None: + if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks: + embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank], ranks[-1]] + if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks: + position_embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank]] + else: + embedding_ranks = ranks + position_embedding_ranks = ranks + + group = torch.distributed.new_group(embedding_ranks) + if rank in embedding_ranks: + _EMBEDDING_GROUP = group + if rank in ranks: + _EMBEDDING_GLOBAL_RANKS = embedding_ranks + + group = torch.distributed.new_group(position_embedding_ranks) + if rank in position_embedding_ranks: + _POSITION_EMBEDDING_GROUP = group + if rank in ranks: + _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks + + # Build the FP8 groups. 
+ global _AMAX_REDUCTION_GROUP + assert _AMAX_REDUCTION_GROUP is None, "FP8 amax reduction group is already initialized" + if use_fp8: + amax_group_size: int = tensor_model_parallel_size * data_parallel_size + num_amax_groups: int = world_size // amax_group_size + for i in range(num_amax_groups): + start_rank = i * amax_group_size + end_rank = (i + 1) * amax_group_size + ranks = range(start_rank, end_rank) + group = torch.distributed.new_group(ranks) + if rank in ranks: + _AMAX_REDUCTION_GROUP = group + + +def is_unitialized(): + """Useful for code segments that may be accessed with or without mpu initialization""" + return _DATA_PARALLEL_GROUP is None + + +def is_training_mode(): + """Useful for code segments that may be accessed with or without mpu initialization""" + global _TRAINING_MODE + if _TRAINING_MODE is True: + return True + else: + return False + + +def set_training_mode(): + """Useful for code segments that may be accessed with or without mpu initialization""" + global _TRAINING_MODE + _TRAINING_MODE = True + + +def set_eval_mode(): + global _TRAINING_MODE + _TRAINING_MODE = False + + +def model_parallel_is_initialized(): + """Check if model and data parallel groups are initialized.""" + if _TENSOR_MODEL_PARALLEL_GROUP is None or _PIPELINE_MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None: + return False + return True + + +def sequence_parallel_is_initialized(): + """Check if sequence and data parallel groups are initialized.""" + if _SEQUENCE_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None: + return False + return True + + +def sequence_data_parallel_is_initialized(): + """Check if sequence data parallel groups are initialized.""" + if _SEQUENCE_DATA_PARALLEL_GROUP is None: + return False + return True + + +def get_model_parallel_group(): + """Get the model parallel group the caller rank belongs to.""" + assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized" + return _MODEL_PARALLEL_GROUP + + +def get_model_parallel_world_size(): + return None + + +def get_model_parallel_rank(): + return 0 + + +def get_sequence_parallel_group(): + """Get the sequence parallel group the caller rank belongs to.""" + assert _SEQUENCE_PARALLEL_GROUP is not None, "sequence parallel group is not initialized" + return _SEQUENCE_PARALLEL_GROUP + + +def get_sequence_data_parallel_group(): + """Get the sequence parallel group the caller rank belongs to.""" + assert _SEQUENCE_DATA_PARALLEL_GROUP is not None, "sequence data parallel group is not initialized" + return _SEQUENCE_DATA_PARALLEL_GROUP + + +def set_sequence_parallel_world_size(world_size): + """Set the sequence parallel size""" + global _SEQUENCE_PARALLEL_WORLD_SIZE + _SEQUENCE_PARALLEL_WORLD_SIZE = world_size + + +def set_sequence_data_parallel_world_size(world_size): + """Set the sequence parallel size""" + global _SEQUENCE_DATA_PARALLEL_WORLD_SIZE + _SEQUENCE_DATA_PARALLEL_WORLD_SIZE = world_size + + +def get_sequence_parallel_world_size(): + """Return world size for the sequence parallel group.""" + global _SEQUENCE_PARALLEL_WORLD_SIZE + if _SEQUENCE_PARALLEL_WORLD_SIZE is not None: + return _SEQUENCE_PARALLEL_WORLD_SIZE + # Context Parallelism is not yet supported for eval + if is_training_mode(): + return torch.distributed.get_world_size(group=get_sequence_parallel_group()) + else: + return 1 + + +def get_sequence_data_parallel_world_size(): + """Return world size for the sequence parallel group.""" + global _SEQUENCE_DATA_PARALLEL_WORLD_SIZE + if _SEQUENCE_DATA_PARALLEL_WORLD_SIZE is not 
None: + return _SEQUENCE_DATA_PARALLEL_WORLD_SIZE + return torch.distributed.get_world_size(group=get_sequence_data_parallel_group()) + + +def get_data_parallel_world_size(): + return get_sequence_data_parallel_world_size() + + +def get_data_parallel_group(): + return get_sequence_data_parallel_group() + + +def set_sequence_parallel_rank(rank): + """Set sequence parallel rank.""" + global _SEQUENCE_PARALLEL_RANK + _SEQUENCE_PARALLEL_RANK = rank + + +def set_sequence_data_parallel_rank(rank): + """Set sequence parallel rank.""" + global _SEQUENCE_DATA_PARALLEL_RANK + _SEQUENCE_DATA_PARALLEL_RANK = rank + + +def get_sequence_parallel_rank(): + """Return my rank for the sequence parallel group.""" + global _SEQUENCE_PARALLEL_RANK + if _SEQUENCE_PARALLEL_RANK is not None: + return _SEQUENCE_PARALLEL_RANK + # Context Parallelism is not yet supported for eval + if is_training_mode(): + return torch.distributed.get_rank(group=get_sequence_parallel_group()) + else: + return 0 + + +def get_sequence_data_parallel_rank(): + """Return my rank for the sequence data parallel group.""" + global _SEQUENCE_DATA_PARALLEL_RANK + if _SEQUENCE_DATA_PARALLEL_RANK is not None: + return _SEQUENCE_DATA_PARALLEL_RANK + return torch.distributed.get_rank(group=get_sequence_data_parallel_group()) + + +def get_sequence_parallel_src_rank(): + """Calculate the global rank corresponding to the first local rank + in the sequence parallel group.""" + global_rank = torch.distributed.get_rank() + local_world_size = get_sequence_parallel_world_size() + return (global_rank // local_world_size) * local_world_size diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 5172e80492..62c4b6cb21 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -21,7 +21,7 @@ ) from transformers.utils import is_torchdynamo_compiling -from .... import distributed +from .... 
import distributed, parallel_state from ....distributed.strategy import DistributedStrategy, NoOpStrategy from ....distributed.tensorparallel import ( reduce_from_tensor_model_parallel_region, @@ -119,7 +119,7 @@ def __init__( self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings # Truncate the cached max sequence length to 8k to limit cached register buffer size - if config.max_position_embeddings > 8192 and self.rope_type == "llama3": + if not self.training and config.max_position_embeddings > 8192 and self.rope_type == "llama3": self.max_seq_len_cached = 8192 self.original_max_seq_len = config.max_position_embeddings @@ -436,6 +436,13 @@ def forward(self, cur, dim, idx): return self.update(self.cache, cur, dim, idx, self.inp_seq_len) +def GaudiDistributedAttention(fused_scaled_dot_product_attention, fused_scaled_dot_product_attention_distributed): + if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1: + return fused_scaled_dot_product_attention_distributed + else: + return fused_scaled_dot_product_attention + + class GaudiLlamaAttention(LlamaAttention): def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) @@ -444,6 +451,7 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): self.matmul_av = Matmul() self.k_cache = KVCache() self.v_cache = KVCache() + if hasattr(config, "fused_qkv") and config.fused_qkv: self.num_heads = config.num_attention_heads self.head_dim = config.hidden_size // self.num_heads @@ -470,6 +478,15 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): if FusedSDPA else None ) + # https://github.com/microsoft/DeepSpeed/issues/4359 + # for all2all comm, Distributed Attention cares about sequence (s) and number of heads (h) dimensions. 
In HPU, they are at 1 and 2 indices + self.fused_scaled_dot_product_attention_distributed = None + if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1: + from deepspeed.sequence.layer import DistributedAttention + + self.fused_scaled_dot_product_attention_distributed = DistributedAttention( + self.fused_scaled_dot_product_attention, parallel_state.get_sequence_parallel_group(), 1, 2 + ) def get_k_proj_weight(self): """4bit quantization in GPTQ replaces the k_proj.weight with qweight.""" @@ -611,8 +628,22 @@ def pre_attn_forward( # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) # else: # cos, sin = position_embeddings + seq_len = kv_seq_len + if parallel_state.sequence_parallel_is_initialized(): + seq_len = kv_seq_len * parallel_state.get_sequence_parallel_world_size() + + cos, sin = self.rotary_emb(value_states, seq_len=seq_len) + # If sequence parallel in enabled, position_ids should be based on which part of the sequence is present in the rank + # As we divide the inputs based on ranks, position_ids are generated to suit that part of the sequence + if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_rank() > 0: + position_ids = torch.arange( + kv_seq_len * parallel_state.get_sequence_parallel_rank(), + kv_seq_len * (parallel_state.get_sequence_parallel_rank() + 1), + dtype=torch.long, + device=query_states.device, + ) + position_ids = position_ids.unsqueeze(0) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids) if use_cache: @@ -659,11 +690,13 @@ def pre_attn_forward( kv_seq_len = key_states.shape[-2] else: past_key_value = None - + fused_scaled_dot_product_attention = GaudiDistributedAttention( + self.fused_scaled_dot_product_attention, self.fused_scaled_dot_product_attention_distributed + ) if use_flash_attention and FusedSDPA is not None: if q_len == 1: # next token - attn_output = self.fused_scaled_dot_product_attention( + attn_output = fused_scaled_dot_product_attention( query_states, key_states, value_states, @@ -680,7 +713,8 @@ def pre_attn_forward( # first token softmax_mode = "fast" if flash_attention_fast_softmax else "None" if flash_attention_causal_mask: - attn_output = self.fused_scaled_dot_product_attention( + # causal masking on first token requires inputs to be of the same length + attn_output = fused_scaled_dot_product_attention( query_states, key_states, value_states, @@ -694,7 +728,7 @@ def pre_attn_forward( "left", ) else: - attn_output = self.fused_scaled_dot_product_attention( + attn_output = fused_scaled_dot_product_attention( query_states, key_states, value_states, @@ -1398,7 +1432,19 @@ def forward( shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) + # Collect losses from context parallel group + # Each rank in group calculates loss on partial outputs + if ( + parallel_state.sequence_parallel_is_initialized() + and parallel_state.get_sequence_parallel_world_size() > 1 + ): + from optimum.habana.distributed.contextparallel import _get_loss_from_context_parallel + + loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + loss_all = _get_loss_from_context_parallel(loss_fct(shift_logits, shift_labels)) + loss = torch.mean(loss_all) + else: + loss = loss_fct(shift_logits, shift_labels) if not return_dict: output = (logits,) + 
outputs[1:] diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index fc79204051..08fb914ccc 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -592,6 +592,11 @@ def _inner_training_loop( # number of training steps per epoch: num_update_steps_per_epoch # total number of training steps to execute: max_steps total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size + if ( + self.accelerator.mpu.sequence_parallel_is_initialized() + and self.accelerator.mpu.get_sequence_parallel_world_size() > 1 + ): + total_train_batch_size = total_train_batch_size / self.accelerator.mpu.get_sequence_parallel_world_size() len_dataloader = None num_train_tokens = None @@ -1538,6 +1543,15 @@ def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, elif isinstance(data, (tuple, list)): return type(data)(self._prepare_input(v) for v in data) elif isinstance(data, torch.Tensor): + if ( + self.accelerator.mpu.sequence_parallel_is_initialized() + and self.accelerator.mpu.get_sequence_parallel_world_size() > 1 + ): + seq_parallel_world_rank = self.accelerator.mpu.get_sequence_parallel_rank() + sub_seq_length = int(data.size()[1] / self.accelerator.mpu.get_sequence_parallel_world_size()) + data = data[ + :, seq_parallel_world_rank * sub_seq_length : (seq_parallel_world_rank + 1) * sub_seq_length + ] kwargs = {"device": self.args.device} if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)): # NLP models inputs are int/uint and those get adjusted to the right dtype of the @@ -1789,7 +1803,6 @@ def evaluate( self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics) self._memory_tracker.stop_and_update_metrics(output.metrics) - return output.metrics def predict( @@ -1975,6 +1988,8 @@ def evaluation_loop( all_losses.add(losses) if labels is not None: labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100) + if self.args.context_parallel_size != 1: + labels = labels.clone() labels = self.gather_function((labels)) if not self.args.batch_eval_metrics or description == "Prediction": all_labels.add(labels) diff --git a/optimum/habana/transformers/training_args.py b/optimum/habana/transformers/training_args.py index 96e7dd7956..79e43efa72 100644 --- a/optimum/habana/transformers/training_args.py +++ b/optimum/habana/transformers/training_args.py @@ -190,6 +190,11 @@ class GaudiTrainingArguments(TrainingArguments): }, ) + context_parallel_size: Optional[int] = field( + default=1, + metadata={"help": ("Determines how many ranks are divided into context parallel group.")}, + ) + throughput_warmup_steps: Optional[int] = field( default=0, metadata={ @@ -921,6 +926,7 @@ def _setup_devices(self) -> "torch.device": else: accelerator_state_kwargs["backend"] = self.ddp_backend accelerator_state_kwargs["timeout"] = timedelta(seconds=self.ddp_timeout) + accelerator_state_kwargs["context_parallel_size"] = self.context_parallel_size else: raise ValueError( "No device has been set. Use either --use_habana to run on HPU or --no_cuda to run on CPU." 
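For reference, a minimal usage sketch of the new `context_parallel_size` training argument introduced above (the output directory, DeepSpeed config path, and batch size below are illustrative placeholders, not part of this patch; the equivalent command-line flag `--context_parallel_size 4` is what the test baseline further down uses):

# Hypothetical sketch: enabling context parallelism through GaudiTrainingArguments.
# Assumes a DeepSpeed launch where context_parallel_size divides the world size.
from optimum.habana import GaudiTrainingArguments

training_args = GaudiTrainingArguments(
    output_dir="./llama_cp_output",                     # placeholder path
    use_habana=True,
    use_lazy_mode=True,
    deepspeed="tests/configs/deepspeed_zero_1.json",    # placeholder DeepSpeed config
    per_device_train_batch_size=8,
    context_parallel_size=4,                            # split each sequence across 4 ranks
)
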
diff --git a/tests/baselines/llama_7b.json b/tests/baselines/llama_7b.json index a5235a8de1..0e8758801e 100644 --- a/tests/baselines/llama_7b.json +++ b/tests/baselines/llama_7b.json @@ -613,6 +613,35 @@ ] } } + }, + "tatsu-lab/alpaca_cp": { + "num_train_epochs": 1, + "eval_batch_size": 4, + "distribution": { + "deepspeed": { + "learning_rate": 3e-4, + "train_batch_size": 8, + "perplexity": 2.3889, + "train_runtime": 147.3597, + "train_samples_per_second": 34.41, + "extra_arguments": [ + "--bf16 True", + "--gradient_accumulation_steps 4", + "--logging_steps 1", + "--lora_rank 8", + "--lora_alpha 16", + "--lora_dropout 0.05", + "--lora_target_modules q_proj v_proj", + "--dataset_concatenation", + "--max_seq_length 2048", + "--pipelining_fwd_bwd", + "--throughput_warmup_steps 3", + "--use_lazy_mode", + "--context_parallel_size 4", + "--deepspeed tests/configs/deepspeed_zero_1.json" + ] + } + } } } } \ No newline at end of file diff --git a/tests/test_examples.py b/tests/test_examples.py index 663d4bcd38..83d1fd7e3b 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -241,6 +241,7 @@ def to_test( "codellama/CodeLlama-13b-Instruct-hf", "MIT/ast-finetuned-speech-commands-v2", "meta-llama/LlamaGuard-7b", + "huggyllama/llama-7b", ] case_only_in_gaudi2 = [ @@ -283,6 +284,7 @@ def to_test( "ia3", "adalora", "ln_tuning", + "tatsu-lab/alpaca_cp", ): return False elif eager_mode and model_name not in models_measured_on_eager_mode: @@ -321,6 +323,8 @@ def to_test( return True elif "ast-finetuned-speech-commands-v2" in model_name and IS_GAUDI2: return True + elif "huggyllama" in model_name and IS_GAUDI2 and deepspeed: + return True elif "gemma" in model_name and IS_GAUDI2: return True @@ -1020,4 +1024,10 @@ class MultiCardCausalLanguageModelingAdaloraExampleTester( ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_lora_clm", multi_card=True ): TASK_NAME = "adalora" + + +class MultiCardCausalLanguageModelingLoRACPExampleTester( + ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_lora_clm", deepspeed=True +): + TASK_NAME = "tatsu-lab/alpaca_cp" DATASET_NAME = "tatsu-lab/alpaca"
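
As an aside for reviewers, here is a minimal single-process sketch of the sequence slicing that `GaudiTrainer._prepare_input` performs when context parallelism is active. The rank and world size are hard-coded instead of coming from `parallel_state`; this is an illustration of the slicing logic only, not code from the patch.

import torch

def slice_for_cp_rank(data: torch.Tensor, cp_rank: int, cp_world_size: int) -> torch.Tensor:
    # Each context-parallel rank keeps one contiguous chunk of the sequence
    # dimension (dim 1), mirroring the slicing added to _prepare_input.
    sub_seq_length = data.size(1) // cp_world_size
    return data[:, cp_rank * sub_seq_length : (cp_rank + 1) * sub_seq_length]

# Example: a batch of 2 sequences of length 8 split across 4 context-parallel ranks.
input_ids = torch.arange(16).reshape(2, 8)
for rank in range(4):
    chunk = slice_for_cp_rank(input_ids, rank, 4)
    print(rank, chunk.shape)  # each rank sees a (2, 2) slice of the sequence

At training time, each rank then computes the per-token loss on its own slice, and `_ContextParallelLoss` all-gathers those partial losses across the sequence-parallel group before averaging, which is why the model patch switches to `CrossEntropyLoss(reduction="none")` when context parallelism is enabled.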