Commit b4670b5

cpu

Signed-off-by: youkaichao <[email protected]>
youkaichao committed Nov 1, 2024
1 parent 14b8af4 commit b4670b5
Showing 4 changed files with 17 additions and 63 deletions.
25 changes: 6 additions & 19 deletions vllm/worker/cpu_model_runner.py
@@ -7,9 +7,7 @@
 from torch import nn

 from vllm.attention import AttentionMetadata, get_attn_backend
-from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, ParallelConfig, PromptAdapterConfig,
-                         SchedulerConfig)
+from vllm.config import EngineConfig
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
@@ -388,29 +386,18 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):

     def __init__(
         self,
-        model_config: ModelConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        device_config: DeviceConfig,
-        cache_config: CacheConfig,
-        load_config: LoadConfig,
-        lora_config: Optional[LoRAConfig],
+        vllm_config: EngineConfig,
         kv_cache_dtype: Optional[str] = "auto",
-        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
         is_driver_worker: bool = False,
         *args,
         **kwargs,
     ):
-        self.model_config = model_config
-        self.parallel_config = parallel_config
-        self.scheduler_config = scheduler_config
+        ModelRunnerBase.__init__(self, vllm_config)
         # Currently, CPU worker doesn't support chunked prefill.
         assert self.scheduler_config.chunked_prefill_enabled is False
-        self.device_config = device_config
-        self.cache_config = cache_config
-        self.lora_config = lora_config
-        self.prompt_adapter_config = prompt_adapter_config
-        self.load_config = load_config
+        model_config = self.model_config
+        cache_config = self.cache_config
+
         self.is_driver_worker = is_driver_worker

         self.device = self.device_config.device
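For orientation, a minimal sketch of the pattern this file moves to: a single bundled config object is handed to the base class, which exposes the sub-configs as attributes (for example the self.scheduler_config used by the assert above). The class and field names below are illustrative stand-ins, not vLLM's actual EngineConfig or ModelRunnerBase.

from dataclasses import dataclass
from typing import Any, Optional

# Hypothetical stand-in for vllm.config.EngineConfig: one object bundling the
# per-subsystem configs that the old constructor accepted one by one.
@dataclass
class EngineConfigSketch:
    model_config: Any
    cache_config: Any
    scheduler_config: Any
    device_config: Any
    load_config: Any
    lora_config: Optional[Any] = None
    prompt_adapter_config: Optional[Any] = None

class ModelRunnerBaseSketch:
    def __init__(self, vllm_config: EngineConfigSketch):
        # Copy the bundled fields onto the runner once, so a subclass like the
        # CPU runner above can simply call the base initializer.
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.load_config = vllm_config.load_config
        self.lora_config = vllm_config.lora_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config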
37 changes: 9 additions & 28 deletions vllm/worker/cpu_worker.py
@@ -6,9 +6,8 @@

 import vllm.envs as envs
 from vllm.attention import get_attn_backend
-from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, ParallelConfig, PromptAdapterConfig,
-                         SchedulerConfig)
+from vllm.config import (CacheConfig, DeviceConfig, EngineConfig, ModelConfig,
+                         ParallelConfig)
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
 from vllm.logger import init_logger
@@ -18,7 +17,8 @@
 from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
 from vllm.worker.cpu_model_runner import CPUModelRunner
 from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
-                                     LoraNotSupportedWorkerBase, WorkerInput)
+                                     LoraNotSupportedWorkerBase, WorkerBase,
+                                     WorkerInput)

 logger = init_logger(__name__)

@@ -121,31 +121,19 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):

     def __init__(
         self,
-        model_config: ModelConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        device_config: DeviceConfig,
-        cache_config: CacheConfig,
-        load_config: LoadConfig,
+        vllm_config: EngineConfig,
         local_rank: int,
         rank: int,
         distributed_init_method: str,
-        lora_config: Optional[LoRAConfig] = None,
         kv_cache_dtype: Optional[str] = "auto",
-        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
         is_driver_worker: bool = False,
     ) -> None:
-        self.model_config = model_config
-        self.parallel_config = parallel_config
-        self.scheduler_config = scheduler_config
-        self.device_config = device_config
-        self.cache_config = cache_config
-        self.load_config = load_config
+        WorkerBase.__init__(self, vllm_config=vllm_config)
+
         self.local_rank = local_rank
         self.rank = rank
         self.distributed_init_method = distributed_init_method
-        self.lora_config = lora_config
-        self.prompt_adapter_config = prompt_adapter_config
+
         self.is_driver_worker = is_driver_worker
         if self.is_driver_worker:
             assert self.rank == 0, "The driver worker must have rank 0."
@@ -166,15 +154,8 @@ def __init__(
         if self._is_encoder_decoder_model():
             ModelRunnerClass = CPUEncoderDecoderModelRunner
         self.model_runner: CPUModelRunner = ModelRunnerClass(
-            model_config,
-            parallel_config,
-            scheduler_config,
-            device_config,
-            cache_config,
-            load_config=self.load_config,
-            lora_config=self.lora_config,
+            vllm_config=vllm_config,
             kv_cache_dtype=kv_cache_dtype,
-            prompt_adapter_config=self.prompt_adapter_config,
             is_driver_worker=is_driver_worker)
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
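A short, runnable usage sketch of the simplified worker call site. SimpleNamespace stands in for the real config objects, and WorkerSketch is a hypothetical stand-in for CPUWorker rather than vLLM's class; it only shows that the constructor now threads one bundle instead of six separate configs.

from types import SimpleNamespace

# Hypothetical bundled config; the real EngineConfig lives in vllm.config.
vllm_config = SimpleNamespace(
    model_config=SimpleNamespace(dtype="float32"),
    cache_config=SimpleNamespace(block_size=16, cache_dtype="auto"),
    parallel_config=SimpleNamespace(pipeline_parallel_size=1),
    scheduler_config=SimpleNamespace(chunked_prefill_enabled=False),
)

class WorkerSketch:
    """Illustrative stand-in for the refactored CPUWorker constructor."""

    def __init__(self, vllm_config, local_rank, rank, distributed_init_method,
                 kv_cache_dtype="auto", is_driver_worker=False):
        # One bundled attribute instead of many; sub-configs stay reachable.
        self.vllm_config = vllm_config
        self.cache_config = vllm_config.cache_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.kv_cache_dtype = kv_cache_dtype
        self.is_driver_worker = is_driver_worker

worker = WorkerSketch(vllm_config, local_rank=0, rank=0,
                      distributed_init_method="tcp://127.0.0.1:29500",
                      is_driver_worker=True)
print(worker.cache_config.block_size)  # sub-configs remain accessible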
3 changes: 1 addition & 2 deletions vllm/worker/multi_step_model_runner.py
@@ -303,8 +303,7 @@ def maybe_advance_frozen_model_input(self, device: str, pin_memory: bool):
 class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
     # mypy: enable-error-code=type-var

-    def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self, base_model_runner: GPUModelRunnerBase):

         # Check attention backend support.
         supported_attention_backends: List[str] = \
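The wrapper no longer forwards *args/**kwargs to the GPU runner base class. A hedged sketch of what the single-argument form implies, assuming the wrapper can recover the bundled config from the runner it wraps; vLLM's actual MultiStepModelRunner may differ in detail.

class BaseRunnerSketch:
    """Stand-in for GPUModelRunnerBase: holds the bundled config."""

    def __init__(self, vllm_config):
        self.vllm_config = vllm_config

class MultiStepWrapperSketch(BaseRunnerSketch):
    def __init__(self, base_model_runner: "BaseRunnerSketch"):
        # Nothing extra to forward: everything the wrapper needs is already
        # attached to the runner it wraps.
        super().__init__(base_model_runner.vllm_config)
        self._base_model_runner = base_model_runner

# Mirrors the simplified call site in multi_step_worker.py below.
wrapped = MultiStepWrapperSketch(BaseRunnerSketch(vllm_config={"placeholder": True}))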
15 changes: 1 addition & 14 deletions vllm/worker/multi_step_worker.py
@@ -25,20 +25,7 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         base_model_runner = self.model_runner
         # for multi-step model, wrap the model runner with MultiStepModelRunner
-        self.model_runner = MultiStepModelRunner(
-            base_model_runner,
-            base_model_runner.model_config,
-            base_model_runner.parallel_config,
-            base_model_runner.scheduler_config,
-            base_model_runner.device_config,
-            base_model_runner.cache_config,
-            load_config=base_model_runner.load_config,
-            lora_config=self.lora_config,
-            kv_cache_dtype=self.cache_config.cache_dtype,
-            is_driver_worker=base_model_runner.is_driver_worker,
-            prompt_adapter_config=base_model_runner.prompt_adapter_config,
-            observability_config=base_model_runner.observability_config,
-        )
+        self.model_runner = MultiStepModelRunner(base_model_runner, )

         pipeline_parallel_size = self.parallel_config.pipeline_parallel_size
         self.multi_step_states: List[
