diff --git a/aphrodite/executor/gpu_executor.py b/aphrodite/executor/gpu_executor.py
index 5249545bf..c4ca3109d 100644
--- a/aphrodite/executor/gpu_executor.py
+++ b/aphrodite/executor/gpu_executor.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
 
 from loguru import logger
 
@@ -9,13 +9,16 @@
 from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from aphrodite.lora.request import LoRARequest
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
-from aphrodite.task_handler.worker_base import WorkerWrapperBase
+from aphrodite.task_handler.worker_base import WorkerBase, WorkerWrapperBase
 
 
-def create_worker(worker_module_name, worker_class_name, **kwargs):
+def create_worker(worker_module_name: str, worker_class_name: str,
+                  worker_class_fn: Optional[Callable[[], Type[WorkerBase]]],
+                  **kwargs):
     wrapper = WorkerWrapperBase(
         worker_module_name=worker_module_name,
         worker_class_name=worker_class_name,
+        worker_class_fn=worker_class_fn,
     )
     wrapper.init_worker(**kwargs)
     return wrapper.worker
@@ -61,7 +64,9 @@ def _get_worker_kwargs(
             or (rank % self.parallel_config.tensor_parallel_size == 0),
         )
 
-    def _get_worker_module_and_class(self) -> Tuple[str, str]:
+    def _get_worker_module_and_class(
+            self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
+        worker_class_fn = None
         if self.scheduler_config.is_multi_step:
             worker_module_name = "aphrodite.task_handler.multi_step_worker"
             worker_class_name = "MultiStepWorker"
@@ -71,7 +76,7 @@ def _get_worker_module_and_class(self) -> Tuple[str, str]:
         else:
             worker_module_name = "aphrodite.task_handler.worker"
             worker_class_name = "Worker"
-        return (worker_module_name, worker_class_name)
+        return (worker_module_name, worker_class_name, worker_class_fn)
 
     def _get_create_worker_kwargs(
             self,
@@ -80,10 +85,13 @@ def _get_create_worker_kwargs(
             distributed_init_method: Optional[str] = None) -> Dict:
         worker_kwargs = self._get_worker_kwargs(local_rank, rank,
                                                 distributed_init_method)
-        (worker_module_name,
-         worker_class_name) = self._get_worker_module_and_class()
-        worker_kwargs.update(worker_module_name=worker_module_name,
-                             worker_class_name=worker_class_name)
+        (worker_module_name, worker_class_name,
+         worker_class_fn) = self._get_worker_module_and_class()
+        worker_kwargs.update(
+            worker_module_name=worker_module_name,
+            worker_class_name=worker_class_name,
+            worker_class_fn=worker_class_fn,
+        )
 
         return worker_kwargs
 
diff --git a/aphrodite/executor/ray_gpu_executor.py b/aphrodite/executor/ray_gpu_executor.py
index de6979132..87e13b964 100644
--- a/aphrodite/executor/ray_gpu_executor.py
+++ b/aphrodite/executor/ray_gpu_executor.py
@@ -101,12 +101,13 @@ def _configure_ray_workers_use_nsight(self,
         return ray_remote_kwargs
 
     def _get_worker_wrapper_args(self) -> Dict[str, Any]:
-        (worker_module_name,
-         worker_class_name) = self._get_worker_module_and_class()
+        (worker_module_name, worker_class_name,
+         worker_class_fn) = self._get_worker_module_and_class()
 
         return dict(
             worker_module_name=worker_module_name,
             worker_class_name=worker_class_name,
+            worker_class_fn=worker_class_fn,
             trust_remote_code=self.model_config.trust_remote_code,
         )
 
diff --git a/aphrodite/executor/xpu_executor.py b/aphrodite/executor/xpu_executor.py
index c6ef999ee..62f2eded0 100644
--- a/aphrodite/executor/xpu_executor.py
+++ b/aphrodite/executor/xpu_executor.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Type, Union
 
 import torch
 from loguru import logger
@@ -12,6 +12,7 @@
 from aphrodite.common.utils import make_async
 from aphrodite.executor.executor_base import ExecutorAsyncBase
 from aphrodite.executor.gpu_executor import GPUExecutor
+from aphrodite.task_handler.worker_base import WorkerBase
 
 
 class XPUExecutor(GPUExecutor):
@@ -48,15 +49,17 @@ def __init__(
 
         # Instantiate the worker and load the model to GPU.
         self._init_executor()
 
-    def _get_worker_module_and_class(self) -> Tuple[str, str]:
+    def _get_worker_module_and_class(
+            self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
+        worker_class_fn = None
         if self.speculative_config is not None:
             raise NotImplementedError(
                 "XPU does not support speculative decoding")
         else:
             worker_module_name = "aphrodite.task_handler.xpu_worker"
             worker_class_name = "XPUWorker"
-        return (worker_module_name, worker_class_name)
+        return (worker_module_name, worker_class_name, worker_class_fn)
 
     def execute_model(
         self, execute_model_req: ExecuteModelRequest