Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Fix for multinode crash on 4 PP #6495

Merged
merged 3 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions tests/distributed/test_pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@


@pytest.mark.parametrize(
"TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME",
[
"TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME", [
(2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B"),
(2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B"),
(1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B"),
# TODO: figure out why PP=4 tests are flaky
# (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B"),
# (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B"),
(1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B"),
(1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B"),
])
def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
pp_args = [
Expand Down
14 changes: 14 additions & 0 deletions vllm/executor/ray_gpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,13 +224,27 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
# broadcasted to.
self.non_driver_workers: List[RayWorkerWrapper] = []

tp_driver_worker_ranks = []
non_driver_worker_ranks = []
for idx, rank in enumerate(worker_ranks[1:]):
# We need to skip the driver worker, which we
# do by skipping worker_ranks[0] which is always 0.
if rank % self.parallel_config.tensor_parallel_size == 0:
self.tp_driver_workers.append(self.workers[idx])
tp_driver_worker_ranks.append(rank)
else:
self.non_driver_workers.append(self.workers[idx])
non_driver_worker_ranks.append(rank)

# Enforce rank order for correct rank to return final output.
self.tp_driver_workers = [
worker for _, worker in sorted(
zip(tp_driver_worker_ranks, self.tp_driver_workers))
]
self.non_driver_workers = [
worker for _, worker in sorted(
zip(non_driver_worker_ranks, self.non_driver_workers))
]

def _driver_execute_model(
self, execute_model_req: Optional[ExecuteModelRequest]
Expand Down
Loading