Conform to new worker/model_runner APIs
kzawora-intel committed Nov 5, 2024
1 parent bb512dd commit c9ce231
Showing 4 changed files with 24 additions and 79 deletions.
4 changes: 4 additions & 0 deletions vllm/distributed/parallel_state.py
@@ -370,6 +370,10 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
# TPU handles Dynamo with its own logic.
return self.tpu_communicator.all_reduce(input_)

if self.hpu_communicator is not None and \
not self.hpu_communicator.disabled:
return self.hpu_communicator.all_reduce(input_)

if self.ca_comm is not None and \
not self.ca_comm.disabled and \
self.ca_comm.should_custom_ar(input_):
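For context, the added branch follows the same dispatch pattern as the existing TPU and custom-allreduce paths: each device communicator is tried in turn and the first enabled one handles the collective. A minimal, self-contained sketch of that pattern; `StubCommunicator` and `dispatch_all_reduce` are illustrative stand-ins, not vLLM's actual classes:

```python
from typing import Optional


class StubCommunicator:
    """Illustrative stand-in for a device-specific communicator (not vLLM's)."""

    def __init__(self, disabled: bool = False) -> None:
        self.disabled = disabled

    def all_reduce(self, values: list) -> list:
        # A real communicator reduces across ranks; this single-process stub
        # just returns the input unchanged.
        return values


def dispatch_all_reduce(values: list,
                        tpu_comm: Optional[StubCommunicator] = None,
                        hpu_comm: Optional[StubCommunicator] = None) -> list:
    # Try each device communicator in turn; the first enabled one handles the
    # collective, mirroring the branch added for HPU above.
    if tpu_comm is not None and not tpu_comm.disabled:
        return tpu_comm.all_reduce(values)
    if hpu_comm is not None and not hpu_comm.disabled:
        return hpu_comm.all_reduce(values)
    # Otherwise fall back to the generic path (torch.distributed in the real code).
    return values


print(dispatch_all_reduce([1.0, 2.0], hpu_comm=StubCommunicator()))
```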
8 changes: 1 addition & 7 deletions vllm/executor/hpu_executor.py
@@ -37,16 +37,10 @@ def _get_worker_kwargs(
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
return dict(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
is_driver_worker=rank == 0,
)

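The executor-side change collapses the long list of individual config arguments into a single bundled config object that is passed down as-is. A rough sketch of that refactor pattern under simplified, assumed config classes (not vLLM's real `VllmConfig`):

```python
from dataclasses import dataclass


# Simplified stand-ins for the individual config objects.
@dataclass
class ModelConfig:
    model: str


@dataclass
class CacheConfig:
    block_size: int


@dataclass
class BundledConfig:
    """Single object carrying every sub-config, forwarded unchanged."""
    model_config: ModelConfig
    cache_config: CacheConfig


def get_worker_kwargs(bundled: BundledConfig, rank: int) -> dict:
    # After the refactor, only the bundle plus per-worker fields are passed.
    return dict(
        vllm_config=bundled,
        rank=rank,
        is_driver_worker=rank == 0,
    )


cfg = BundledConfig(ModelConfig("my-model"), CacheConfig(block_size=128))
print(get_worker_kwargs(cfg, rank=0))
```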
51 changes: 13 additions & 38 deletions vllm/worker/hpu_model_runner.py
@@ -26,9 +26,7 @@
HabanaMemoryProfiler, format_bytes)

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.config import DeviceConfig, VllmConfig
from vllm.distributed.parallel_state import get_world_group
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
@@ -516,36 +514,18 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):

def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
vllm_config: VllmConfig,
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
return_hidden_states: bool = False,
observability_config: Optional[ObservabilityConfig] = None,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.lora_config = lora_config
self.load_config = load_config
ModelRunnerBase.__init__(self, vllm_config=vllm_config)
self.is_driver_worker = is_driver_worker
self.prompt_adapter_config = prompt_adapter_config
self.return_hidden_states = return_hidden_states
self.observability_config = observability_config

self.sliding_window = (model_config.get_sliding_window()
if model_config is not None else None)
self.device_config = (device_config
if device_config is not None else DeviceConfig())

self.sliding_window = (self.model_config.get_sliding_window()
if self.model_config is not None else None)
self.device_config = (self.device_config if self.device_config
is not None else DeviceConfig())
self.device = self.device_config.device
self.enforce_eager = self.model_config.enforce_eager
self.max_num_seqs = self.scheduler_config.max_num_seqs
@@ -555,14 +535,13 @@ def __init__(
self.max_model_len = self.scheduler_config.max_model_len
self.max_num_batched_tokens = \
self.scheduler_config.max_num_batched_tokens
self.block_size = cache_config.block_size
self.block_size = self.cache_config.block_size

self.pin_memory = is_pin_memory_available()
self.kv_cache_dtype = kv_cache_dtype
self.kv_cache_dtype = self.cache_config.cache_dtype

self.attn_backend = get_attn_backend(
self.model_config.get_head_size(),
self.model_config.get_sliding_window(),
self.model_config.dtype,
self.kv_cache_dtype,
self.block_size,
@@ -616,13 +595,7 @@ def load_model(self) -> None:
htcore.hpu_set_env()
with HabanaMemoryProfiler() as m:
with HabanaMemoryProfiler() as m_getmodel:
self.model = get_model(model_config=self.model_config,
device_config=self.device_config,
load_config=self.load_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config)
self.model = get_model(vllm_config=self.vllm_config)
msg = ("Pre-loading model weights on "
f"{next(self.model.parameters()).device} "
f"took {m_getmodel.get_summary_string()}")
@@ -901,6 +874,8 @@ def _prepare_prompt(
num_prefill_tokens=sum_query_len,
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=
None  # FIXME(kzawora): multi-modality will not work here
)
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)

@@ -1054,7 +1029,7 @@ def _prepare_decode(
num_prefill_tokens=0,
num_decode_tokens=num_decode_tokens,
slot_mapping=slot_mapping,
)
multi_modal_placeholder_index_maps=None)
return PrepareDecodeMetadata(input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
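In the constructor hunk above, the model runner now receives the bundled config and lets its base class unpack it into the familiar `self.model_config`, `self.cache_config`, etc. attributes, so the body reads sub-configs through `self.` instead of individual parameters. A schematic sketch of that delegation, using hypothetical classes rather than vLLM's real `ModelRunnerBase`:

```python
from dataclasses import dataclass, field


@dataclass
class SchedulerConfig:
    max_num_seqs: int = 256


@dataclass
class CacheConfig:
    block_size: int = 128
    cache_dtype: str = "auto"


@dataclass
class BundledConfig:
    scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig)
    cache_config: CacheConfig = field(default_factory=CacheConfig)


class RunnerBase:
    def __init__(self, vllm_config: BundledConfig) -> None:
        # The base class unpacks the bundle once...
        self.vllm_config = vllm_config
        self.scheduler_config = vllm_config.scheduler_config
        self.cache_config = vllm_config.cache_config


class DeviceRunner(RunnerBase):
    def __init__(self,
                 vllm_config: BundledConfig,
                 is_driver_worker: bool = False) -> None:
        RunnerBase.__init__(self, vllm_config=vllm_config)
        self.is_driver_worker = is_driver_worker
        # ...so subclasses read sub-configs through `self.` instead of taking
        # them as separate constructor arguments (e.g. kv_cache_dtype is now
        # derived from the cache config rather than passed explicitly).
        self.max_num_seqs = self.scheduler_config.max_num_seqs
        self.block_size = self.cache_config.block_size
        self.kv_cache_dtype = self.cache_config.cache_dtype


runner = DeviceRunner(BundledConfig())
print(runner.max_num_seqs, runner.block_size, runner.kv_cache_dtype)
```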
40 changes: 6 additions & 34 deletions vllm/worker/hpu_worker.py
@@ -12,10 +12,7 @@
from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes

import vllm.envs as envs
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
@@ -26,7 +23,8 @@
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.hpu_model_runner import HPUModelRunner
from vllm.worker.model_runner_base import ModelRunnerBase
from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
WorkerInput)

logger = init_logger(__name__)

@@ -41,34 +39,18 @@ class HPUWorker(LocalOrDistributedWorkerBase):

def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False,
model_runner_cls: Optional[Type[ModelRunnerBase]] = None,
observability_config: Optional[ObservabilityConfig] = None,
) -> None:
self.model_config = model_config
self.parallel_config = parallel_config
WorkerBase.__init__(self, vllm_config=vllm_config)
self.parallel_config.rank = rank
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.load_config = load_config
self.prompt_adapter_config = prompt_adapter_config
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
@@ -79,17 +61,7 @@ def __init__(
init_cached_hf_modules()

self.model_runner: HPUModelRunner = HPUModelRunner(
model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config=load_config,
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker,
prompt_adapter_config=prompt_adapter_config,
observability_config=observability_config)
vllm_config=vllm_config, is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine: List[HPUCacheEngine]
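Taken together with the executor change above, the worker now forwards the same bundled config straight to its model runner, so the per-config plumbing disappears at every layer; only genuinely per-process fields (rank, driver flag) are still passed separately. A compact sketch of that chain under assumed, illustrative class names (not vLLM's):

```python
from dataclasses import dataclass


@dataclass
class ParallelConfig:
    rank: int = 0


@dataclass
class BundledConfig:
    parallel_config: ParallelConfig


class WorkerBase:
    def __init__(self, vllm_config: BundledConfig) -> None:
        self.vllm_config = vllm_config
        self.parallel_config = vllm_config.parallel_config


class ModelRunner:
    def __init__(self,
                 vllm_config: BundledConfig,
                 is_driver_worker: bool = False) -> None:
        self.vllm_config = vllm_config
        self.is_driver_worker = is_driver_worker


class Worker(WorkerBase):
    def __init__(self,
                 vllm_config: BundledConfig,
                 rank: int,
                 is_driver_worker: bool = False) -> None:
        WorkerBase.__init__(self, vllm_config=vllm_config)
        # Only per-process fields are set locally.
        self.parallel_config.rank = rank
        self.rank = rank
        self.is_driver_worker = is_driver_worker
        # The model runner receives the very same bundle, nothing re-derived.
        self.model_runner = ModelRunner(vllm_config=vllm_config,
                                        is_driver_worker=is_driver_worker)


w = Worker(BundledConfig(ParallelConfig()), rank=0, is_driver_worker=True)
print(w.model_runner.vllm_config is w.vllm_config)  # True: one shared config object
```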
