Skip to content

Commit

Permalink
[HPU][Bugfix] set_forward_context and CI test execution (vllm-project…
Browse files Browse the repository at this point in the history
…#12014)

Signed-off-by: Konrad Zawora <[email protected]>
  • Loading branch information
kzawora-intel authored Jan 14, 2025
1 parent 1a40125 commit 078da31
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 18 deletions.
7 changes: 5 additions & 2 deletions .buildkite/run-hpu-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@ set -ex
docker build -t hpu-test-env -f Dockerfile.hpu .

# Setup cleanup
EXITCODE=1
remove_docker_container() { docker rm -f hpu-test || true; }
trap remove_docker_container EXIT
remove_docker_container_and_exit() { remove_docker_container; exit $EXITCODE; }
trap remove_docker_container_and_exit EXIT
remove_docker_container

# Run the image and launch offline inference
docker run --runtime=habana --name=hpu-test --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic.py
docker run --runtime=habana --name=hpu-test --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic.py
EXITCODE=$?
2 changes: 1 addition & 1 deletion Dockerfile.hpu
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest
FROM vault.habana.ai/gaudi-docker/1.19.1/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest

COPY ./ /workspace/vllm

Expand Down
32 changes: 17 additions & 15 deletions vllm/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,12 +289,14 @@ def forward_hook(module, args, output):

class HpuModelAdapter:

def __init__(self, model, block_size, dtype, enforce_eager):
def __init__(self, model, vllm_config):
self.model = model
self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
'0').lower() in ['1', 'true']
self.block_size = block_size
self.dtype = dtype
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
self.dtype = vllm_config.model_config.dtype
enforce_eager = vllm_config.model_config.enforce_eager
if not htorch.utils.internal.is_lazy() and not enforce_eager:
self.model = torch.compile(self.model,
backend='hpu_backend',
Expand Down Expand Up @@ -353,14 +355,20 @@ def forward(self, *args, **kwargs):
selected_token_indices = kwargs.pop('selected_token_indices')
if 'warmup_mode' in kwargs:
kwargs.pop('warmup_mode')
virtual_engine = 0
if 'virtual_engine' in kwargs:
virtual_engine = kwargs.pop('virtual_engine')
input_ids = kwargs['input_ids']
kwargs['attn_metadata'] = self._update_metadata(
kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1),
input_ids.device, self.dtype)
LoraMask.setLoraMask(kwargs.pop('lora_mask'))
hidden_states = self.model(*args, **kwargs)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
hidden_states = hidden_states.index_select(0, selected_token_indices)
with set_forward_context(kwargs['attn_metadata'], self.vllm_config,
virtual_engine):
hidden_states = self.model(*args, **kwargs)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
hidden_states = hidden_states.index_select(0,
selected_token_indices)
return hidden_states

def compute_logits(self, *args, **kwargs):
Expand Down Expand Up @@ -660,10 +668,7 @@ def load_model(self) -> None:

with HabanaMemoryProfiler() as m_wrap:
self.model = _maybe_wrap_in_hpu_graph(
self.model,
self.block_size,
dtype=self.model_config.dtype,
enforce_eager=self.enforce_eager)
self.model, vllm_config=self.vllm_config)
msg = f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}"
logger.info(msg)

Expand Down Expand Up @@ -1934,6 +1939,7 @@ def execute_model(
"attn_metadata": self.trim_attn_metadata(attn_metadata),
"intermediate_tensors": intermediate_tensors,
"lora_mask": lora_mask,
"virtual_engine": model_input.virtual_engine,
**(model_input.multi_modal_kwargs or {}),
}
if htorch.utils.internal.is_lazy():
Expand All @@ -1948,11 +1954,7 @@ def execute_model(
f"graphs{'T' if use_graphs else 'F'}")
else:
model_event_name = 'model_executable'
with set_forward_context(
model_input.attn_metadata, self.vllm_config,
model_input.virtual_engine), \
self.profiler.record_event(
'internal', model_event_name):
with self.profiler.record_event('internal', model_event_name):
hidden_states = self.model.forward(
**execute_model_kwargs,
selected_token_indices=sampling_metadata.selected_token_indices
Expand Down

0 comments on commit 078da31

Please sign in to comment.