xpu: refactor XPU worker & executor #861

Merged (1 commit) on Dec 4, 2024
32 changes: 11 additions & 21 deletions aphrodite/executor/xpu_executor.py
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List, Optional, Tuple, Union

import torch
from loguru import logger
@@ -7,11 +7,11 @@
LoRAConfig, ModelConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import (ExecuteModelRequest, PoolerOutput,
+                                       SamplerOutput)
from aphrodite.common.utils import make_async
from aphrodite.executor.executor_base import ExecutorAsyncBase
from aphrodite.executor.gpu_executor import GPUExecutor
-from aphrodite.task_handler.worker_base import WorkerWrapperBase


class XPUExecutor(GPUExecutor):
@@ -49,28 +49,18 @@ def __init__(
# Instantiate the worker and load the model to GPU.
self._init_executor()

-    def _create_worker(self,
-                       local_rank: int = 0,
-                       rank: int = 0,
-                       distributed_init_method: Optional[str] = None):
-        if self.speculative_config is None:
-            worker_module_name = "aphrodite.task_handler.xpu_worker"
-            worker_class_name = "XPUWorker"
-        else:
+    def _get_worker_module_and_class(self) -> Tuple[str, str]:
+        if self.speculative_config is not None:
             raise NotImplementedError(
                 "XPU does not support speculative decoding")
-
-        wrapper = WorkerWrapperBase(
-            worker_module_name=worker_module_name,
-            worker_class_name=worker_class_name,
-        )
-        wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
-                                                      distributed_init_method))
-        return wrapper.worker
+        else:
+            worker_module_name = "aphrodite.task_handler.xpu_worker"
+            worker_class_name = "XPUWorker"
+        return (worker_module_name, worker_class_name)

     def execute_model(
-        self,
-        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        self, execute_model_req: ExecuteModelRequest
+    ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
         output = self.driver_worker.execute_model(execute_model_req)
         return output

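The refactor replaces XPUExecutor's full _create_worker override with a small _get_worker_module_and_class hook, leaving worker construction to the inherited GPUExecutor. A minimal sketch of that hook pattern, assuming the base class wraps workers the same way the removed override did (the base-class source is not part of this diff, so treat the details as an assumption):

    # Illustrative only, not the actual GPUExecutor source.
    from typing import Optional, Tuple

    from aphrodite.task_handler.worker_base import WorkerWrapperBase


    class ExecutorSketch:

        def _get_worker_module_and_class(self) -> Tuple[str, str]:
            # Overridden by XPUExecutor as shown in the diff above.
            return ("aphrodite.task_handler.xpu_worker", "XPUWorker")

        def _create_worker(self,
                           local_rank: int = 0,
                           rank: int = 0,
                           distributed_init_method: Optional[str] = None):
            worker_module_name, worker_class_name = \
                self._get_worker_module_and_class()
            # WorkerWrapperBase loads the named module and instantiates the
            # worker class, mirroring the removed XPUExecutor override.
            wrapper = WorkerWrapperBase(
                worker_module_name=worker_module_name,
                worker_class_name=worker_class_name,
            )
            # _get_worker_kwargs comes from the real base executor (it is
            # used by the removed code above); omitted from this sketch.
            wrapper.init_worker(**self._get_worker_kwargs(
                local_rank, rank, distributed_init_method))
            return wrapper.worker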
9 changes: 7 additions & 2 deletions aphrodite/task_handler/xpu_worker.py
@@ -63,8 +63,9 @@ def __init__(
self.lora_config = lora_config
self.prompt_adapter_config = prompt_adapter_config
self.is_driver_worker = is_driver_worker
-        if self.is_driver_worker:
-            assert self.rank == 0, "The driver worker must have rank 0."
+        if parallel_config and is_driver_worker:
+            assert rank % parallel_config.tensor_parallel_size == 0, \
+                "Driver worker should be rank 0 of tensor parallel group."

self.multimodal_config = multimodal_config

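The relaxed assertion permits one driver per tensor-parallel group instead of a single global rank-0 driver, which matters once pipeline parallelism splits the world into several TP groups. A runnable illustration with hypothetical sizes (world_size=8, tensor_parallel_size=4, i.e. two pipeline stages; these values are not from the PR):

    tensor_parallel_size = 4
    world_size = 8  # 4-way tensor parallel x 2 pipeline stages

    for rank in range(world_size):
        # The new check: each TP group's first rank is its driver.
        is_tp_driver = rank % tensor_parallel_size == 0
        print(f"rank {rank}: TP-group driver = {is_tp_driver}")
    # -> ranks 0 and 4 are drivers; the old `rank == 0` check would have
    #    rejected rank 4 as a driver worker.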
@@ -175,7 +176,11 @@ def init_worker_distributed_environment(self) -> None:
         # dependency (libdrm and drm headers) on your system.
         ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE",
                                             "sockets")
+        ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE",
+                                         str(parallel_config.world_size))
         os.environ['CCL_ZE_IPC_EXCHANGE'] = ENV_CCL_ZE_IPC_EXCHANGE
+        os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
+        os.environ["LOCAL_RANK"] = str(self.local_rank)
         init_distributed_environment(
             world_size=parallel_config.world_size,
             rank=rank,
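For context, the added exports publish the worker's local topology before init_distributed_environment runs: LOCAL_WORLD_SIZE falls back to the global world size when no launcher has set it, while LOCAL_RANK is always overwritten; both are presumably read by the oneCCL/XPU backend during bootstrap. A standalone sketch of the same precedence logic, with hypothetical values standing in for parallel_config and the worker's local_rank:

    import os

    # Hypothetical single-node, 2-device run; the real worker derives these
    # from parallel_config.world_size and self.local_rank.
    world_size = 2
    local_rank = 0

    # Respect values a launcher already exported, otherwise fall back,
    # mirroring the worker code above. LOCAL_RANK is set unconditionally.
    os.environ["CCL_ZE_IPC_EXCHANGE"] = os.getenv("CCL_ZE_IPC_EXCHANGE",
                                                  "sockets")
    os.environ["LOCAL_WORLD_SIZE"] = os.getenv("LOCAL_WORLD_SIZE",
                                               str(world_size))
    os.environ["LOCAL_RANK"] = str(local_rank)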