diff --git a/vllm/executor/torchrun_gpu_executor.py b/vllm/executor/torchrun_gpu_executor.py
index 88823ba5d4920..837e6e9368e77 100644
--- a/vllm/executor/torchrun_gpu_executor.py
+++ b/vllm/executor/torchrun_gpu_executor.py
@@ -2,6 +2,7 @@
 import os
 from typing import Dict, List, Optional
 
+from vllm.executor.gpu_executor import GPUExecutor
 from vllm.lora.request import LoRARequest
 from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, LoRAConfig)
@@ -22,7 +23,7 @@
 }
 
 
-class TorchrunGPUExecutor(ExecutorBase):
+class TorchrunGPUExecutor(GPUExecutor):
 
     def __init__(
         self,
@@ -33,27 +34,15 @@ def __init__(
         device_config: DeviceConfig,
         lora_config: Optional[LoRAConfig],
     ) -> None:
-        self.model_config = model_config
-        self.cache_config = cache_config
-        self.lora_config = lora_config
-        self.parallel_config = parallel_config
-        self.scheduler_config = scheduler_config
-        self.device_config = device_config
         self.local_rank = int(os.getenv("LOCAL_RANK", "0"))
         self.is_driver_worker = self.local_rank == 0
+        super().__init__(model_config,
+                         cache_config,
+                         parallel_config,
+                         scheduler_config,
+                         device_config,
+                         lora_config)
 
-        # Instantiate the worker and load the model to GPU.
-        self._init_worker()
-
-        # Profile the memory usage and initialize the cache.
-        self._init_cache()
-
-    def _dispatch_worker(self):
-        worker_module = DEVICE_TO_WORKER_MODULE_MAP[
-            self.device_config.device_type]
-        imported_worker = importlib.import_module(worker_module)
-        Worker = imported_worker.Worker
-        return Worker
 
     def _init_worker(self):
         # Lazy import the Worker to avoid importing torch.cuda/xformers
@@ -65,7 +54,7 @@ def _init_worker(self):
 
         distributed_init_method = get_distributed_init_method(
            get_ip(), get_open_port())
-        self.worker = Worker(
+        self.driver_worker = Worker(
             self.model_config,
             self.parallel_config,
             self.scheduler_config,
@@ -77,50 +66,15 @@ def _init_worker(self):
             kv_cache_dtype=self.cache_config.cache_dtype,
             is_driver_worker=self.is_driver_worker,
         )
-        self.worker.init_model()
-        self.worker.load_model()
-
-    def _init_cache(self) -> None:
-        """Profiles the memory usage and initializes the KV cache.
-
-        The engine first profiles the existing memory usage.
-        Then, it allocates the remaining memory for KV blocks.
-
-        .. tip::
-            You may limit the usage of GPU memory
-            by adjusting the `gpu_memory_utilization` parameter.
-        """
-        # Get the maximum number of blocks that can be allocated on GPU and CPU.
-        num_gpu_blocks, num_cpu_blocks = (
-            self.worker.profile_num_available_blocks(
-                block_size=self.cache_config.block_size,
-                gpu_memory_utilization=self.cache_config.
-                gpu_memory_utilization,
-                cpu_swap_space=self.cache_config.swap_space_bytes,
-                cache_dtype=self.cache_config.cache_dtype,
-            ))
-
-        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
-                    f"# CPU blocks: {num_cpu_blocks}")
-
-        check_block_size_valid(num_gpu_blocks, self.cache_config.block_size,
-                               self.model_config.max_model_len)
-
-        self.cache_config.num_gpu_blocks = num_gpu_blocks
-        self.cache_config.num_cpu_blocks = num_cpu_blocks
-
-        # Initialize the cache.
-        self.worker.init_cache_engine(cache_config=self.cache_config)
-        # Warm up the model. This includes capturing the model into CUDA graph
-        # if enforce_eager is False.
-        self.worker.warm_up_model()
+        self.driver_worker.init_model()
+        self.driver_worker.load_model()
 
     def execute_model(self,
                       seq_group_metadata_list: List[SequenceGroupMetadata],
                       blocks_to_swap_in: Dict[int, int],
                       blocks_to_swap_out: Dict[int, int],
                       blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
-        output = self.worker.execute_model(
+        output = self.driver_worker.execute_model(
             seq_group_metadata_list=seq_group_metadata_list,
             blocks_to_swap_in=blocks_to_swap_in,
             blocks_to_swap_out=blocks_to_swap_out,
@@ -134,23 +88,6 @@ def execute_model(self,
             output = res[0]
         return output
 
-    def add_lora(self, lora_request: LoRARequest) -> bool:
-        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
-        return self.worker.add_lora(lora_request)
-
-    def remove_lora(self, lora_id: int) -> bool:
-        assert lora_id > 0, "lora_id must be greater than 0."
-        return self.worker.remove_lora(lora_id)
-
-    def list_loras(self) -> List[int]:
-        return self.worker.list_loras()
-
-    def check_health(self) -> None:
-        # TorchrunGPUExecutor will always be healthy as long as
-        # it's running.
-        return
-
-
 class TorchrunGPUExecutorAsync(TorchrunGPUExecutor, ExecutorAsyncBase):
 
     async def execute_model_async(
@@ -160,7 +97,7 @@ async def execute_model_async(
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> SamplerOutput:
-        output = await make_async(self.worker.execute_model)(
+        output = await make_async(self.driver_worker.execute_model)(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 6eb645ca675d9..aff5369c74e6c 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -77,16 +77,13 @@ def __init__(
         self.act_fn = SiluAndMul()
 
     def forward(self, x):
-        #print(f'>>>Shape of x in mlp {x.shape} {self.gate_up_proj.weight.shape}')
         if x.shape[0] == 1 and x.shape[1] == 1:
-            out = torch.empty(x.shape[0],self.gate_up_proj.weight.shape[0]//2,dtype=x.dtype,device=x.device)
             custom_ops.LLMM_Silu(self.gate_up_proj.weight,x.view(-1,x.size(-1)),out,8)
             x = out.view(x.shape[0], x.shape[1], out.shape[1])
         else:
             gate_up, _ = self.gate_up_proj(x)
             x = self.act_fn(gate_up)
-        #print(f'>>> x.shape {x.shape}')
         x, _ = self.down_proj(x)
         return x
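The executor change above hinges on an ordering detail: local_rank and is_driver_worker are assigned before super().__init__() because the base GPUExecutor constructor is assumed (judging by the calls the subclass deletes) to store the configs and then immediately invoke _init_worker() and _init_cache(), and the overridden _init_worker() reads is_driver_worker. The stand-alone sketch below uses hypothetical class names (BaseExecutorSketch, TorchrunExecutorSketch) rather than the real vllm.executor code, purely to illustrate that pattern.

```python
# Sketch only: simplified stand-ins, not the real vLLM classes.
import os


class BaseExecutorSketch:

    def __init__(self, *configs):
        self.configs = configs
        # Assumed behavior of GPUExecutor: the base constructor drives
        # worker construction and cache setup through overridable hooks.
        self._init_worker()
        self._init_cache()

    def _init_worker(self):
        raise NotImplementedError

    def _init_cache(self):
        pass  # the real code profiles memory, sizes the KV cache, warms up


class TorchrunExecutorSketch(BaseExecutorSketch):

    def __init__(self, *configs):
        # Must run before super().__init__(), which calls _init_worker().
        self.local_rank = int(os.getenv("LOCAL_RANK", "0"))
        self.is_driver_worker = self.local_rank == 0
        super().__init__(*configs)

    def _init_worker(self):
        # The real override builds a Worker and stores it under the name
        # self.driver_worker, which the rest of the diff now uses.
        self.driver_worker = ("worker", self.is_driver_worker)


if __name__ == "__main__":
    executor = TorchrunExecutorSketch("model_cfg", "cache_cfg")
    print(executor.driver_worker)
```

With this delegation in place, the subclass only customizes worker construction; the KV-cache profiling, LoRA management, and health-check methods it previously duplicated are inherited from the base class, which is why the diff can delete them wholesale.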