From 9094a8a2a3eea8c9d26adb4b28048e7d835526a9 Mon Sep 17 00:00:00 2001
From: AlpinDale <52078762+AlpinDale@users.noreply.github.com>
Date: Wed, 4 Dec 2024 00:34:40 -0800
Subject: [PATCH] xpu: refactor XPU worker & executor (#861)

---
 aphrodite/executor/xpu_executor.py   | 32 ++++++++++------------------
 aphrodite/task_handler/xpu_worker.py |  9 ++++++--
 2 files changed, 18 insertions(+), 23 deletions(-)

diff --git a/aphrodite/executor/xpu_executor.py b/aphrodite/executor/xpu_executor.py
index ecc2cfd4a..c6ef999ee 100644
--- a/aphrodite/executor/xpu_executor.py
+++ b/aphrodite/executor/xpu_executor.py
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List, Optional, Tuple, Union
 
 import torch
 from loguru import logger
@@ -7,11 +7,11 @@
                                      LoRAConfig, ModelConfig, ParallelConfig,
                                      PromptAdapterConfig, SchedulerConfig,
                                      SpeculativeConfig)
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import (ExecuteModelRequest, PoolerOutput,
+                                       SamplerOutput)
 from aphrodite.common.utils import make_async
 from aphrodite.executor.executor_base import ExecutorAsyncBase
 from aphrodite.executor.gpu_executor import GPUExecutor
-from aphrodite.task_handler.worker_base import WorkerWrapperBase
 
 
 class XPUExecutor(GPUExecutor):
@@ -49,28 +49,18 @@ def __init__(
         # Instantiate the worker and load the model to GPU.
         self._init_executor()
 
-    def _create_worker(self,
-                       local_rank: int = 0,
-                       rank: int = 0,
-                       distributed_init_method: Optional[str] = None):
-        if self.speculative_config is None:
-            worker_module_name = "aphrodite.task_handler.xpu_worker"
-            worker_class_name = "XPUWorker"
-        else:
+    def _get_worker_module_and_class(self) -> Tuple[str, str]:
+        if self.speculative_config is not None:
             raise NotImplementedError(
                 "XPU does not support speculative decoding")
-
-        wrapper = WorkerWrapperBase(
-            worker_module_name=worker_module_name,
-            worker_class_name=worker_class_name,
-        )
-        wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
-                                                      distributed_init_method))
-        return wrapper.worker
+        else:
+            worker_module_name = "aphrodite.task_handler.xpu_worker"
+            worker_class_name = "XPUWorker"
+        return (worker_module_name, worker_class_name)
 
     def execute_model(
-        self,
-        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        self, execute_model_req: ExecuteModelRequest
+    ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
         output = self.driver_worker.execute_model(execute_model_req)
         return output
 
diff --git a/aphrodite/task_handler/xpu_worker.py b/aphrodite/task_handler/xpu_worker.py
index e28fc39f3..3ded25965 100644
--- a/aphrodite/task_handler/xpu_worker.py
+++ b/aphrodite/task_handler/xpu_worker.py
@@ -63,8 +63,9 @@ def __init__(
         self.lora_config = lora_config
         self.prompt_adapter_config = prompt_adapter_config
         self.is_driver_worker = is_driver_worker
-        if self.is_driver_worker:
-            assert self.rank == 0, "The driver worker must have rank 0."
+        if parallel_config and is_driver_worker:
+            assert rank % parallel_config.tensor_parallel_size == 0, \
+                "Driver worker should be rank 0 of tensor parallel group."
 
         self.multimodal_config = multimodal_config
 
@@ -175,7 +176,11 @@ def init_worker_distributed_environment(self) -> None:
         # dependency (libdrm and drm headers) on your system.
         ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE", "sockets")
+        ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE",
+                                         str(parallel_config.world_size))
         os.environ['CCL_ZE_IPC_EXCHANGE'] = ENV_CCL_ZE_IPC_EXCHANGE
+        os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
+        os.environ["LOCAL_RANK"] = str(self.local_rank)
         init_distributed_environment(
             world_size=parallel_config.world_size,
             rank=rank,
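Note (not part of the patch): the diff replaces XPUExecutor's full _create_worker
override with the smaller _get_worker_module_and_class hook, leaving worker
construction to the inherited GPUExecutor path through WorkerWrapperBase. The
sketch below illustrates the assumed shape of that base-class path; the class
name GPUExecutorSketch and the exact wiring are illustrative, not copied from
aphrodite/executor/gpu_executor.py, and _get_worker_kwargs is assumed to be
provided by the executor as in the code this patch removes.

    from typing import Optional, Tuple

    from aphrodite.task_handler.worker_base import WorkerWrapperBase


    class GPUExecutorSketch:  # hypothetical stand-in for GPUExecutor
        def _get_worker_module_and_class(self) -> Tuple[str, str]:
            # XPUExecutor now only overrides this hook, returning
            # ("aphrodite.task_handler.xpu_worker", "XPUWorker").
            return ("aphrodite.task_handler.worker", "Worker")

        def _create_worker(self,
                           local_rank: int = 0,
                           rank: int = 0,
                           distributed_init_method: Optional[str] = None):
            # Same pattern as the override removed by this patch, but living
            # in the base class: resolve the worker class by name, wrap it,
            # and initialize it with the executor's worker kwargs.
            worker_module_name, worker_class_name = \
                self._get_worker_module_and_class()
            wrapper = WorkerWrapperBase(worker_module_name=worker_module_name,
                                        worker_class_name=worker_class_name)
            # _get_worker_kwargs is assumed to exist on the executor, as it
            # does in the removed XPUExecutor._create_worker.
            wrapper.init_worker(**self._get_worker_kwargs(
                local_rank, rank, distributed_init_method))
            return wrapper.worker

With this shape, per-backend executors only need to name their worker class,
which is what makes the 32-line override collapse into the 8-line hook above.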