xpu: refactor XPU worker & executor (#861)
AlpinDale authored Dec 4, 2024
1 parent 8b8d2ce commit 9094a8a
Showing 2 changed files with 18 additions and 23 deletions.
32 changes: 11 additions & 21 deletions aphrodite/executor/xpu_executor.py
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List, Optional, Tuple, Union
 
 import torch
 from loguru import logger
@@ -7,11 +7,11 @@
                                      LoRAConfig, ModelConfig, ParallelConfig,
                                      PromptAdapterConfig, SchedulerConfig,
                                      SpeculativeConfig)
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import (ExecuteModelRequest, PoolerOutput,
+                                       SamplerOutput)
 from aphrodite.common.utils import make_async
 from aphrodite.executor.executor_base import ExecutorAsyncBase
 from aphrodite.executor.gpu_executor import GPUExecutor
-from aphrodite.task_handler.worker_base import WorkerWrapperBase
 
 
 class XPUExecutor(GPUExecutor):
@@ -49,28 +49,18 @@ def __init__(
         # Instantiate the worker and load the model to GPU.
         self._init_executor()
 
-    def _create_worker(self,
-                       local_rank: int = 0,
-                       rank: int = 0,
-                       distributed_init_method: Optional[str] = None):
-        if self.speculative_config is None:
-            worker_module_name = "aphrodite.task_handler.xpu_worker"
-            worker_class_name = "XPUWorker"
-        else:
+    def _get_worker_module_and_class(self) -> Tuple[str, str]:
+        if self.speculative_config is not None:
             raise NotImplementedError(
                 "XPU does not support speculative decoding")
-
-        wrapper = WorkerWrapperBase(
-            worker_module_name=worker_module_name,
-            worker_class_name=worker_class_name,
-        )
-        wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
-                                                      distributed_init_method))
-        return wrapper.worker
+        else:
+            worker_module_name = "aphrodite.task_handler.xpu_worker"
+            worker_class_name = "XPUWorker"
+        return (worker_module_name, worker_class_name)
 
     def execute_model(
-        self,
-        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        self, execute_model_req: ExecuteModelRequest
+    ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
         output = self.driver_worker.execute_model(execute_model_req)
         return output

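The executor-side change replaces XPU-specific worker construction with a hook that only names the worker module and class; `execute_model` is also widened to return `PoolerOutput` for pooling/embedding runs. The deleted `_create_worker` body presumably now lives in the shared `GPUExecutor` path. A minimal sketch of that base-class flow, mirrored from the deleted lines (the actual `GPUExecutor` code is not part of this commit):

    # Sketch of the base-class flow this refactor delegates to, reconstructed
    # from the deleted XPU code; GPUExecutor's real implementation is not
    # shown in this commit.
    from aphrodite.task_handler.worker_base import WorkerWrapperBase

    def _create_worker(self,
                       local_rank: int = 0,
                       rank: int = 0,
                       distributed_init_method: Optional[str] = None):
        # Subclasses such as XPUExecutor override this hook.
        worker_module_name, worker_class_name = \
            self._get_worker_module_and_class()
        wrapper = WorkerWrapperBase(
            worker_module_name=worker_module_name,
            worker_class_name=worker_class_name,
        )
        wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
                                                      distributed_init_method))
        return wrapper.worker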
9 changes: 7 additions & 2 deletions aphrodite/task_handler/xpu_worker.py
@@ -63,8 +63,9 @@ def __init__(
         self.lora_config = lora_config
         self.prompt_adapter_config = prompt_adapter_config
         self.is_driver_worker = is_driver_worker
-        if self.is_driver_worker:
-            assert self.rank == 0, "The driver worker must have rank 0."
+        if parallel_config and is_driver_worker:
+            assert rank % parallel_config.tensor_parallel_size == 0, \
+                "Driver worker should be rank 0 of tensor parallel group."
 
         self.multimodal_config = multimodal_config
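The old assertion hard-coded rank 0 as the only driver; the new check allows one driver per tensor parallel group. For intuition (the numbers below are illustrative, not from the commit):

    # Illustrative only: which global ranks pass the new driver check.
    world_size = 8            # hypothetical: 2 nodes x 4 XPUs
    tensor_parallel_size = 4  # hypothetical TP degree

    drivers = [rank for rank in range(world_size)
               if rank % tensor_parallel_size == 0]
    print(drivers)  # [0, 4] -- rank 0 of each tensor parallel group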
@@ -175,7 +176,11 @@ def init_worker_distributed_environment(self) -> None:
         # dependency (libdrm and drm headers) on your system.
         ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE",
                                             "sockets")
+        ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE",
+                                         str(parallel_config.world_size))
         os.environ['CCL_ZE_IPC_EXCHANGE'] = ENV_CCL_ZE_IPC_EXCHANGE
+        os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
+        os.environ["LOCAL_RANK"] = str(self.local_rank)
         init_distributed_environment(
             world_size=parallel_config.world_size,
             rank=rank,
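The worker now exports `LOCAL_WORLD_SIZE` and `LOCAL_RANK` before initializing the process group. These are the per-node rank conventions used by common launchers (e.g. torchrun), and oneCCL-based setups conventionally read them to bind each process to one device on its node; the exact consumer is not shown in this commit. A hypothetical illustration of that binding:

    import os

    # Hypothetical consumer of the exported variables: bind each local
    # process to one XPU on its node. Not part of this commit.
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", "1"))
    assert 0 <= local_rank < local_world_size
    device = f"xpu:{local_rank}"  # one device per local rank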
