From c9ce23111ebea6f3e6cabbe6471188865552a139 Mon Sep 17 00:00:00 2001
From: Konrad Zawora
Date: Tue, 5 Nov 2024 15:12:06 +0200
Subject: [PATCH] Conform to new worker/model_runner APIs

---
 vllm/distributed/parallel_state.py |  4 +++
 vllm/executor/hpu_executor.py      |  8 +----
 vllm/worker/hpu_model_runner.py    | 51 ++++++++----------------
 vllm/worker/hpu_worker.py          | 40 ++++-------------------
 4 files changed, 24 insertions(+), 79 deletions(-)

diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index d3ccd49797068..efa3525910a5e 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -370,6 +370,10 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
             # TPU handles Dynamo with its own logic.
             return self.tpu_communicator.all_reduce(input_)
 
+        if self.hpu_communicator is not None and \
+                not self.hpu_communicator.disabled:
+            return self.hpu_communicator.all_reduce(input_)
+
         if self.ca_comm is not None and \
                 not self.ca_comm.disabled and \
                 self.ca_comm.should_custom_ar(input_):
diff --git a/vllm/executor/hpu_executor.py b/vllm/executor/hpu_executor.py
index 34879bc4e7ef5..220e9eee87bb3 100644
--- a/vllm/executor/hpu_executor.py
+++ b/vllm/executor/hpu_executor.py
@@ -37,16 +37,10 @@ def _get_worker_kwargs(
         distributed_init_method = get_distributed_init_method(
             get_ip(), get_open_port())
         return dict(
-            model_config=self.model_config,
-            parallel_config=self.parallel_config,
-            scheduler_config=self.scheduler_config,
-            device_config=self.device_config,
-            cache_config=self.cache_config,
-            load_config=self.load_config,
+            vllm_config=self.vllm_config,
             local_rank=local_rank,
             rank=rank,
             distributed_init_method=distributed_init_method,
-            lora_config=self.lora_config,
             is_driver_worker=rank == 0,
         )
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index ed71ba8f853a7..5008a2abd22ea 100644
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -26,9 +26,7 @@
                                            HabanaMemoryProfiler, format_bytes)
 
 from vllm.attention import AttentionMetadata, get_attn_backend
-from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, ObservabilityConfig, ParallelConfig,
-                         PromptAdapterConfig, SchedulerConfig)
+from vllm.config import DeviceConfig, VllmConfig
 from vllm.distributed.parallel_state import get_world_group
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping
@@ -516,36 +514,18 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
 
     def __init__(
         self,
-        model_config: ModelConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        device_config: DeviceConfig,
-        cache_config: CacheConfig,
-        load_config: LoadConfig,
-        lora_config: Optional[LoRAConfig],
-        kv_cache_dtype: Optional[str] = "auto",
+        vllm_config: VllmConfig,
         is_driver_worker: bool = False,
-        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
         return_hidden_states: bool = False,
-        observability_config: Optional[ObservabilityConfig] = None,
     ):
-        self.model_config = model_config
-        self.parallel_config = parallel_config
-        self.scheduler_config = scheduler_config
-        self.device_config = device_config
-        self.cache_config = cache_config
-        self.lora_config = lora_config
-        self.load_config = load_config
+        ModelRunnerBase.__init__(self, vllm_config=vllm_config)
         self.is_driver_worker = is_driver_worker
-        self.prompt_adapter_config = prompt_adapter_config
         self.return_hidden_states = return_hidden_states
-        self.observability_config = observability_config
-
-        self.sliding_window = (model_config.get_sliding_window()
-                               if model_config is not None else None)
-        self.device_config = (device_config
-                              if device_config is not None else DeviceConfig())
+        self.sliding_window = (self.model_config.get_sliding_window()
+                               if self.model_config is not None else None)
+        self.device_config = (self.device_config if self.device_config
+                              is not None else DeviceConfig())
         self.device = self.device_config.device
         self.enforce_eager = self.model_config.enforce_eager
         self.max_num_seqs = self.scheduler_config.max_num_seqs
@@ -555,14 +535,13 @@ def __init__(
         self.max_model_len = self.scheduler_config.max_model_len
         self.max_num_batched_tokens = \
             self.scheduler_config.max_num_batched_tokens
-        self.block_size = cache_config.block_size
+        self.block_size = self.cache_config.block_size
 
         self.pin_memory = is_pin_memory_available()
-        self.kv_cache_dtype = kv_cache_dtype
+        self.kv_cache_dtype = self.cache_config.cache_dtype
 
         self.attn_backend = get_attn_backend(
             self.model_config.get_head_size(),
-            self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,
@@ -616,13 +595,7 @@ def load_model(self) -> None:
             htcore.hpu_set_env()
         with HabanaMemoryProfiler() as m:
             with HabanaMemoryProfiler() as m_getmodel:
-                self.model = get_model(model_config=self.model_config,
-                                       device_config=self.device_config,
-                                       load_config=self.load_config,
-                                       lora_config=self.lora_config,
-                                       parallel_config=self.parallel_config,
-                                       scheduler_config=self.scheduler_config,
-                                       cache_config=self.cache_config)
+                self.model = get_model(vllm_config=self.vllm_config)
             msg = ("Pre-loading model weights on "
                    f"{next(self.model.parameters()).device} "
                    f"took {m_getmodel.get_summary_string()}")
@@ -901,6 +874,8 @@ def _prepare_prompt(
             num_prefill_tokens=sum_query_len,
             num_decode_tokens=0,
             slot_mapping=slot_mapping,
+            multi_modal_placeholder_index_maps=
+            None  # FIXME(kzawora): multi-modality will not work here
         )
         multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
@@ -1054,7 +1029,7 @@ def _prepare_decode(
             num_prefill_tokens=0,
             num_decode_tokens=num_decode_tokens,
             slot_mapping=slot_mapping,
-        )
+            multi_modal_placeholder_index_maps=None)
         return PrepareDecodeMetadata(input_tokens=input_tokens,
                                      input_positions=input_positions,
                                      attn_metadata=attn_metadata,
diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py
index e33926aa0b7fd..493f7a9fad098 100644
--- a/vllm/worker/hpu_worker.py
+++ b/vllm/worker/hpu_worker.py
@@ -12,10 +12,7 @@
 from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
 
 import vllm.envs as envs
-from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, ObservabilityConfig, ParallelConfig,
-                         PromptAdapterConfig, SchedulerConfig,
-                         SpeculativeConfig)
+from vllm.config import ParallelConfig, VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
 from vllm.logger import init_logger
@@ -26,7 +23,8 @@
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.hpu_model_runner import HPUModelRunner
 from vllm.worker.model_runner_base import ModelRunnerBase
-from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput
+from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
+                                     WorkerInput)
 
 logger = init_logger(__name__)
@@ -41,34 +39,18 @@ class HPUWorker(LocalOrDistributedWorkerBase):
 
     def __init__(
         self,
-        model_config: ModelConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        device_config: DeviceConfig,
-        cache_config: CacheConfig,
-        load_config: LoadConfig,
+        vllm_config: VllmConfig,
         local_rank: int,
         rank: int,
         distributed_init_method: str,
-        lora_config: Optional[LoRAConfig] = None,
-        speculative_config: Optional[SpeculativeConfig] = None,
-        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
         is_driver_worker: bool = False,
         model_runner_cls: Optional[Type[ModelRunnerBase]] = None,
-        observability_config: Optional[ObservabilityConfig] = None,
     ) -> None:
-        self.model_config = model_config
-        self.parallel_config = parallel_config
+        WorkerBase.__init__(self, vllm_config=vllm_config)
         self.parallel_config.rank = rank
-        self.scheduler_config = scheduler_config
-        self.device_config = device_config
-        self.cache_config = cache_config
         self.local_rank = local_rank
         self.rank = rank
         self.distributed_init_method = distributed_init_method
-        self.lora_config = lora_config
-        self.load_config = load_config
-        self.prompt_adapter_config = prompt_adapter_config
         self.is_driver_worker = is_driver_worker
         if self.is_driver_worker:
             assert self.rank == 0, "The driver worker must have rank 0."
@@ -79,17 +61,7 @@ def __init__(
             init_cached_hf_modules()
 
         self.model_runner: HPUModelRunner = HPUModelRunner(
-            model_config,
-            parallel_config,
-            scheduler_config,
-            device_config,
-            cache_config,
-            load_config=load_config,
-            lora_config=self.lora_config,
-            kv_cache_dtype=self.cache_config.cache_dtype,
-            is_driver_worker=is_driver_worker,
-            prompt_adapter_config=prompt_adapter_config,
-            observability_config=observability_config)
+            vllm_config=vllm_config, is_driver_worker=is_driver_worker)
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
         self.cache_engine: List[HPUCacheEngine]
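
Reviewer note (not part of the patch): the change collapses the per-config
constructor signatures into a single VllmConfig argument; WorkerBase.__init__
and ModelRunnerBase.__init__ fan that object out into the familiar
self.model_config, self.cache_config, self.parallel_config, and so on, which
is why the explicit assignments above could be deleted. A minimal
before/after sketch of a call site, assuming an already-built
`engine_config: VllmConfig` and a `dist_init_method` string (both variable
names are illustrative, not taken from this diff):

    # Before: every sub-config threaded through by hand.
    worker = HPUWorker(
        model_config=engine_config.model_config,
        parallel_config=engine_config.parallel_config,
        scheduler_config=engine_config.scheduler_config,
        device_config=engine_config.device_config,
        cache_config=engine_config.cache_config,
        load_config=engine_config.load_config,
        local_rank=0,
        rank=0,
        distributed_init_method=dist_init_method,
        is_driver_worker=True,
    )

    # After: one top-level config object; the worker/model runner base
    # classes unpack it, so adding a new sub-config no longer ripples
    # through every constructor signature in the stack.
    worker = HPUWorker(
        vllm_config=engine_config,
        local_rank=0,
        rank=0,
        distributed_init_method=dist_init_method,
        is_driver_worker=True,
    )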