Commit

Adding support for Context Parallelism using DeepSpeed's DistributedAttention (huggingface#1501)

Co-authored-by: regisss <[email protected]>
bhargaveede and regisss authored Dec 3, 2024
1 parent 9f9b41e commit 0fbc457
Showing 10 changed files with 625 additions and 11 deletions.
4 changes: 3 additions & 1 deletion optimum/habana/accelerate/accelerator.py
@@ -59,6 +59,8 @@
from accelerate.utils.other import is_compiled_module
from torch.optim.lr_scheduler import LRScheduler

from .. import parallel_state


if is_deepspeed_available():
    from accelerate.utils import (
@@ -123,6 +125,7 @@ def __init__(
        force_autocast: bool = False,
    ):
        self.trackers = []
        self.mpu = parallel_state
        if project_config is not None:
            self.project_configuration = project_config
        else:
@@ -775,7 +778,6 @@ def _prepare_deepspeed(self, *args):
            # This env variable is initialized here to make sure it is set to "true"
            # It should be done by the launcher but it does not work for multi-node runs
            os.environ["DEEPSPEED_USE_HPU"] = "true"

            engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs)
            # torch.compile should be called if dynamo plugin backend is set and only if the model isn't already compiled.
            if self.state.dynamo_plugin.backend == GaudiDynamoBackend.HPU_BACKEND and not is_compiled_module(
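The self.mpu = parallel_state assignment exposes the context-parallel topology on the accelerator itself. A minimal sketch of how downstream code could query it, assuming a distributed run has already been launched and using only the helpers that appear elsewhere in this diff (sequence_parallel_is_initialized, get_sequence_parallel_world_size, get_sequence_parallel_rank); this is illustrative, not code from the commit:

from optimum.habana.accelerate import GaudiAccelerator

accelerator = GaudiAccelerator()
# mpu is the parallel_state module, so the sequence/context-parallel helpers are available on it
if accelerator.mpu.sequence_parallel_is_initialized():
    cp_size = accelerator.mpu.get_sequence_parallel_world_size()
    cp_rank = accelerator.mpu.get_sequence_parallel_rank()
    print(f"context-parallel rank {cp_rank} of {cp_size}")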
5 changes: 5 additions & 0 deletions optimum/habana/accelerate/data_loader.py
@@ -22,6 +22,7 @@
)
from torch.utils.data import BatchSampler, DataLoader, IterableDataset

from .. import parallel_state
from .state import GaudiAcceleratorState
from .utils.operations import (
    broadcast,
@@ -306,6 +307,10 @@ def gaudi_prepare_data_loader(
        num_processes = state.num_processes
    if process_index is None:
        process_index = state.process_index
    if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
        num_processes = int(num_processes / parallel_state.get_sequence_parallel_world_size())
    if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
        process_index = int(process_index / parallel_state.get_sequence_parallel_world_size())

    # Sanity check
    if split_batches:
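To make the arithmetic above concrete: with 8 processes and a context-parallel group size of 4, the data loader behaves as if there were only 2 data-parallel workers, and the 4 ranks of each context-parallel group all read the same shard. A standalone sketch of that rank mapping (the helper name below is illustrative, not code from this commit):

def dataloader_shard(process_index: int, num_processes: int, cp_world_size: int):
    # Mirrors the logic above: ranks belonging to the same context-parallel group
    # collapse onto a single data-parallel worker so they receive identical batches.
    if cp_world_size > 1:
        num_processes = int(num_processes / cp_world_size)
        process_index = int(process_index / cp_world_size)
    return process_index, num_processes

for rank in range(8):
    shard, num_shards = dataloader_shard(rank, 8, 4)
    print(f"global rank {rank} -> data shard {shard} of {num_shards}")
# ranks 0-3 map to shard 0 of 2, ranks 4-7 map to shard 1 of 2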
6 changes: 5 additions & 1 deletion optimum/habana/accelerate/state.py
@@ -21,6 +21,7 @@

from optimum.utils import logging

from .. import parallel_state
from .utils import GaudiDistributedType


@@ -50,7 +51,7 @@ def __init__(self, cpu: bool = False, **kwargs):
        if int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu:
            world_size, rank, local_rank = initialize_distributed_hpu()
            self.backend = kwargs.pop("backend", "hccl")

            context_parallel_size = kwargs.pop("context_parallel_size", 1)
            if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
                if not is_deepspeed_available():
                    raise ImportError(
@@ -85,6 +86,9 @@ def __init__(self, cpu: bool = False, **kwargs):
            if self.device is None:
                # TODO: replace by `torch.device("hpu", self.local_process_index)` when hpu:x is supported
                self.device = torch.device("hpu")
            if not is_deepspeed_available():
                context_parallel_size = 1
            parallel_state.initialize_model_parallel(sequence_parallel_size=context_parallel_size, use_fp8=False)
        else:
            self.distributed_type = (
                GaudiDistributedType.NO
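Stripped of the accelerator plumbing, the change above amounts to creating the sequence-parallel (context-parallel) process groups once distributed HPU is up. A minimal standalone sketch, assuming the job was started with a standard launcher (RANK/WORLD_SIZE set) and that the group size comes from a hypothetical environment variable rather than the context_parallel_size kwarg this commit actually uses:

import os

import habana_frameworks.torch.distributed.hccl  # noqa: F401, registers the hccl backend on Gaudi
import torch

from optimum.habana import parallel_state

# Hypothetical knob for this sketch; the commit takes the value from the
# context_parallel_size kwarg popped in the state initialization above.
context_parallel_size = int(os.environ.get("CONTEXT_PARALLEL_SIZE", "1"))

torch.distributed.init_process_group(backend="hccl")
parallel_state.initialize_model_parallel(sequence_parallel_size=context_parallel_size, use_fp8=False)
assert parallel_state.sequence_parallel_is_initialized()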
30 changes: 30 additions & 0 deletions optimum/habana/distributed/contextparallel.py
@@ -0,0 +1,30 @@
import torch

from ..parallel_state import (
    get_sequence_parallel_group,
    get_sequence_parallel_rank,
    get_sequence_parallel_world_size,
)


# Gather losses across context parallel group
class _ContextParallelLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, loss):
        ctx.seqlen = loss.size(0) * get_sequence_parallel_world_size()

        loss_all = torch.empty(ctx.seqlen, dtype=loss.dtype, device=loss.device)
        torch.distributed.all_gather_into_tensor(loss_all, loss, group=get_sequence_parallel_group())
        return loss_all

    @staticmethod
    def backward(ctx, grad_output):
        step_seqlen = ctx.seqlen // get_sequence_parallel_world_size()
        sp_rank = get_sequence_parallel_rank()
        grad_output_part = grad_output[step_seqlen * sp_rank : step_seqlen * (sp_rank + 1)]

        return grad_output_part, None


def _get_loss_from_context_parallel(vocab_parallel_loss):
    return _ContextParallelLoss.apply(vocab_parallel_loss)
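For context, a hedged sketch of how this helper could feed a causal-LM loss: each context-parallel rank computes an unreduced per-token loss on its local slice of the sequence, the autograd function gathers the slices so the scalar matches what a single device would report on the full sequence, and backward hands each rank only the gradient for its own slice. Shapes and the wrapper below are illustrative assumptions, not code from this commit:

import torch
import torch.nn.functional as F

from optimum.habana.distributed.contextparallel import _get_loss_from_context_parallel


def context_parallel_cross_entropy(shift_logits: torch.Tensor, shift_labels: torch.Tensor) -> torch.Tensor:
    # Unreduced per-token loss over the local sequence slice, shape (local_seq_len,)
    local_loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        reduction="none",
    )
    # Gather the per-token losses from every rank in the sequence-parallel group
    # and average, so the result covers the full (unsplit) sequence.
    full_loss = _get_loss_from_context_parallel(local_loss)
    return full_loss.mean()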