Commit

[core][distributed] custom allreduce when pp size > 1 (vllm-project#6117)
youkaichao authored and jimpang committed Jul 8, 2024
1 parent 9f6d182 commit 34193d5
Showing 2 changed files with 17 additions and 15 deletions.
16 changes: 5 additions & 11 deletions vllm/config.py
@@ -723,17 +723,11 @@ def _verify_args(self) -> None:
         if self.distributed_executor_backend == "ray":
             from vllm.executor import ray_utils
             ray_utils.assert_ray_available()
-        if not self.disable_custom_all_reduce and self.world_size > 1:
-            if is_hip():
-                self.disable_custom_all_reduce = True
-                logger.info(
-                    "Disabled the custom all-reduce kernel because it is not "
-                    "supported on AMD GPUs.")
-            elif self.pipeline_parallel_size > 1:
-                self.disable_custom_all_reduce = True
-                logger.info(
-                    "Disabled the custom all-reduce kernel because it is not "
-                    "supported with pipeline parallelism.")
+        if is_hip():
+            self.disable_custom_all_reduce = True
+            logger.info(
+                "Disabled the custom all-reduce kernel because it is not "
+                "supported on AMD GPUs.")
         if self.ray_workers_use_nsight and (
                 not self.distributed_executor_backend == "ray"):
             raise ValueError("Unable to use nsight profiling unless workers "
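
The net effect of this hunk is that pipeline_parallel_size > 1 no longer forces the custom all-reduce kernel off; only ROCm (is_hip()) still disables it. A minimal standalone sketch of the resulting decision, with a toy resolve_disable_custom_all_reduce helper standing in for the real _verify_args logic (hypothetical names, not vllm code):

# Toy sketch of the check after this commit; not the actual vllm.config code.
def resolve_disable_custom_all_reduce(disable_flag: bool,
                                      pipeline_parallel_size: int,
                                      on_amd_gpu: bool) -> bool:
    """Return the effective disable_custom_all_reduce setting."""
    if on_amd_gpu:  # stands in for is_hip()
        return True
    # pipeline_parallel_size is intentionally ignored now: pp > 1 no longer
    # disables the custom all-reduce kernel.
    return disable_flag


# pp size 2 on non-AMD GPUs keeps custom all-reduce enabled after this change.
assert resolve_disable_custom_all_reduce(False, 2, on_amd_gpu=False) is False
assert resolve_disable_custom_all_reduce(False, 2, on_amd_gpu=True) is True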
16 changes: 12 additions & 4 deletions vllm/distributed/parallel_state.py
@@ -719,14 +719,19 @@ def init_world_group(ranks: List[int], local_rank: int,
     )


-def init_model_parallel_group(group_ranks: List[List[int]], local_rank: int,
-                              backend: str) -> GroupCoordinator:
+def init_model_parallel_group(
+        group_ranks: List[List[int]],
+        local_rank: int,
+        backend: str,
+        use_custom_allreduce: Optional[bool] = None) -> GroupCoordinator:
+    if use_custom_allreduce is None:
+        use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
     return GroupCoordinator(
         group_ranks=group_ranks,
         local_rank=local_rank,
         torch_distributed_backend=backend,
         use_pynccl=True,
-        use_custom_allreduce=_ENABLE_CUSTOM_ALL_REDUCE,
+        use_custom_allreduce=use_custom_allreduce,
     )
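
The new keyword argument follows the usual "None means inherit the global default" pattern, so a caller can opt one group out of custom all-reduce without touching _ENABLE_CUSTOM_ALL_REDUCE. A minimal standalone sketch of that pattern (hypothetical make_group and _GLOBAL_DEFAULT names, not the vllm API):

from typing import Optional

_GLOBAL_DEFAULT = True  # stand-in for _ENABLE_CUSTOM_ALL_REDUCE


def make_group(use_custom_allreduce: Optional[bool] = None) -> dict:
    # None -> inherit the process-wide setting; True/False -> explicit override.
    if use_custom_allreduce is None:
        use_custom_allreduce = _GLOBAL_DEFAULT
    return {"use_custom_allreduce": use_custom_allreduce}


assert make_group() == {"use_custom_allreduce": True}    # inherits the default
assert make_group(use_custom_allreduce=False) == {"use_custom_allreduce": False}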


@@ -888,8 +893,11 @@ def initialize_model_parallel(
     for i in range(num_pipeline_model_parallel_groups):
         ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
         group_ranks.append(ranks)
-    _PP = init_model_parallel_group(group_ranks,
-                                    get_world_group().local_rank, backend)
+    # pipeline parallel does not need custom allreduce
+    _PP = init_model_parallel_group(group_ranks,
+                                    get_world_group().local_rank,
+                                    backend,
+                                    use_custom_allreduce=False)


 def ensure_model_parallel_initialized(
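
As the added comment notes, pipeline parallelism only moves activations point-to-point between stages, so the PP groups never use an all-reduce, while tensor-parallel groups keep following the global flag. A small standalone reproduction of the rank-grouping loop above shows which groups are now created with use_custom_allreduce=False (example sizes assumed, not from the commit):

# Standalone reproduction of the PP rank grouping shown above (toy sizes).
world_size = 8
pipeline_model_parallel_size = 2
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size

group_ranks = []
for i in range(num_pipeline_model_parallel_groups):
    ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
    group_ranks.append(ranks)

# 8 ranks with pp size 2 -> 4 PP groups of 2 ranks each; each of these groups
# is now built with use_custom_allreduce=False, while TP groups still follow
# the global _ENABLE_CUSTOM_ALL_REDUCE flag.
assert group_ranks == [[0, 4], [1, 5], [2, 6], [3, 7]]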
