Refactor torchrun executor to reuse single gpu executor code
gshtras committed Mar 21, 2024
1 parent 9b1388c · commit 6b186bb
Showing 2 changed files with 13 additions and 79 deletions.
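
At a glance: TorchrunGPUExecutor now derives from GPUExecutor rather than ExecutorBase, so the duplicated worker construction, KV-cache profiling, and LoRA plumbing are inherited from the base class, and only the torchrun-specific rank handling remains. A minimal sketch of the resulting constructor, pieced together from the diff below (config types elided; not the full file):

import os

from vllm.executor.gpu_executor import GPUExecutor


class TorchrunGPUExecutor(GPUExecutor):

    def __init__(self, model_config, cache_config, parallel_config,
                 scheduler_config, device_config, lora_config) -> None:
        # Set the rank bookkeeping before the base constructor runs: the
        # original code called _init_worker() from __init__, and the
        # overridden _init_worker() (see diff) reads self.is_driver_worker,
        # so the base class is assumed to trigger it during construction.
        self.local_rank = int(os.getenv("LOCAL_RANK", "0"))
        self.is_driver_worker = self.local_rank == 0
        super().__init__(model_config, cache_config, parallel_config,
                         scheduler_config, device_config, lora_config)
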
89 changes: 13 additions & 76 deletions vllm/executor/torchrun_gpu_executor.py
@@ -2,6 +2,7 @@
import os
from typing import Dict, List, Optional

from vllm.executor.gpu_executor import GPUExecutor
from vllm.lora.request import LoRARequest
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, LoRAConfig)
@@ -22,7 +23,7 @@
}


class TorchrunGPUExecutor(ExecutorBase):
class TorchrunGPUExecutor(GPUExecutor):

def __init__(
self,
@@ -33,27 +34,15 @@ def __init__(
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.local_rank = int(os.getenv("LOCAL_RANK", "0"))
self.is_driver_worker = self.local_rank == 0
super().__init__(model_config,
cache_config,
parallel_config,
scheduler_config,
device_config,
lora_config)

# Instantiate the worker and load the model to GPU.
self._init_worker()

# Profile the memory usage and initialize the cache.
self._init_cache()

def _dispatch_worker(self):
worker_module = DEVICE_TO_WORKER_MODULE_MAP[
self.device_config.device_type]
imported_worker = importlib.import_module(worker_module)
Worker = imported_worker.Worker
return Worker

def _init_worker(self):
# Lazy import the Worker to avoid importing torch.cuda/xformers
@@ -65,7 +54,7 @@ def _init_worker(self):

distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
self.worker = Worker(
self.driver_worker = Worker(
self.model_config,
self.parallel_config,
self.scheduler_config,
@@ -77,50 +66,15 @@ def _init_worker(self):
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=self.is_driver_worker,
)
self.worker.init_model()
self.worker.load_model()

def _init_cache(self) -> None:
"""Profiles the memory usage and initializes the KV cache.
The engine first profiles the existing memory usage.
Then, it allocates the remaining memory for KV blocks.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_gpu_blocks, num_cpu_blocks = (
self.worker.profile_num_available_blocks(
block_size=self.cache_config.block_size,
gpu_memory_utilization=self.cache_config.
gpu_memory_utilization,
cpu_swap_space=self.cache_config.swap_space_bytes,
cache_dtype=self.cache_config.cache_dtype,
))

logger.info(f"# GPU blocks: {num_gpu_blocks}, "
f"# CPU blocks: {num_cpu_blocks}")

check_block_size_valid(num_gpu_blocks, self.cache_config.block_size,
self.model_config.max_model_len)

self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks

# Initialize the cache.
self.worker.init_cache_engine(cache_config=self.cache_config)
# Warm up the model. This includes capturing the model into CUDA graph
# if enforce_eager is False.
self.worker.warm_up_model()
self.driver_worker.init_model()
self.driver_worker.load_model()

def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
output = self.worker.execute_model(
output = self.driver_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
@@ -134,23 +88,6 @@ def execute_model(self,
output = res[0]
return output

def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
return self.worker.add_lora(lora_request)

def remove_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self.worker.remove_lora(lora_id)

def list_loras(self) -> List[int]:
return self.worker.list_loras()

def check_health(self) -> None:
# TorchrunGPUExecutor will always be healthy as long as
# it's running.
return


class TorchrunGPUExecutorAsync(TorchrunGPUExecutor, ExecutorAsyncBase):

async def execute_model_async(
@@ -160,7 +97,7 @@
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> SamplerOutput:
output = await make_async(self.worker.execute_model)(
output = await make_async(self.driver_worker.execute_model)(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
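The async variant above reuses the same driver_worker and only wraps the blocking execute_model call so it can be awaited. make_async is vLLM's small utility for pushing a synchronous callable onto the event loop's default thread pool; a generic stand-in with the assumed behavior (not vLLM's exact implementation):

import asyncio
import functools
from typing import Any, Callable


def make_async(func: Callable[..., Any]) -> Callable[..., "asyncio.Future[Any]"]:
    # Assumed behavior: bind the arguments, then defer the blocking call to
    # the default thread pool so the caller can await the result without
    # stalling the event loop.
    def _async_wrapper(*args, **kwargs) -> "asyncio.Future[Any]":
        loop = asyncio.get_event_loop()
        p_func = functools.partial(func, *args, **kwargs)
        return loop.run_in_executor(None, p_func)

    return _async_wrapper
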
3 changes: 0 additions & 3 deletions vllm/model_executor/models/llama.py
@@ -77,16 +77,13 @@ def __init__(
self.act_fn = SiluAndMul()

def forward(self, x):
#print(f'>>>Shape of x in mlp {x.shape} {self.gate_up_proj.weight.shape}')
if x.shape[0] == 1 and x.shape[1] == 1:

out = torch.empty(x.shape[0],self.gate_up_proj.weight.shape[0]//2,dtype=x.dtype,device=x.device)
custom_ops.LLMM_Silu(self.gate_up_proj.weight,x.view(-1,x.size(-1)),out,8)
x = out.view(x.shape[0], x.shape[1], out.shape[1])
else:
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
#print(f'>>> x.shape {x.shape}')
x, _ = self.down_proj(x)
return x

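For context on the surrounding code (only the commented-out debug prints were deleted): the forward pass takes a fused ROCm path, custom_ops.LLMM_Silu, when the input is a single decode token, and the plain gate_up_proj + SiluAndMul path otherwise. The fused kernel is assumed to compute silu(x @ W_gate.T) * (x @ W_up.T) in one call, with W_gate and W_up being the two stacked halves of gate_up_proj.weight; a plain-PyTorch statement of that assumed contract, with small illustrative sizes:

import torch
import torch.nn.functional as F

hidden_size, intermediate_size = 64, 128            # illustrative sizes only
x = torch.randn(1, 1, hidden_size)                  # a single decode token
gate_up_weight = torch.randn(2 * intermediate_size, hidden_size)

# gate_up_proj stacks the gate and up projection weights along dim 0.
gate_w, up_w = gate_up_weight.split(intermediate_size, dim=0)
x2d = x.view(-1, hidden_size)
ref = F.silu(x2d @ gate_w.t()) * (x2d @ up_w.t())   # what LLMM_Silu is assumed to produce
ref = ref.view(x.shape[0], x.shape[1], intermediate_size)  # mirrors out.view(...) in the diff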
