Chunked Prefill #656

Draft · wants to merge 3 commits into base: habana_main
Changes from all commits
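For reviewers: the changes below split each mixed batch into a prefill slice (the first num_prefill_tokens rows) and a decode slice (the remaining single-token decode steps), write both into the paged KV cache, and run a separate attention path over each before concatenating the outputs. A minimal sketch of the splitting step in plain PyTorch; the function name and the example shapes are illustrative, not part of the fork's API:

```python
import torch


def split_mixed_batch(tokens: torch.Tensor,
                      num_prefill_tokens: int,
                      num_prefills: int):
    """Split a flat [total_tokens, hidden] batch the way this PR does.

    The first `num_prefill_tokens` rows belong to prompts being prefilled
    and are regrouped per prompt; the remaining rows are one-token decode
    steps and get a singleton sequence dimension.
    """
    prefill = tokens[:num_prefill_tokens].reshape(
        num_prefills, num_prefill_tokens // num_prefills, -1)
    decode = tokens[num_prefill_tokens:].reshape(-1, 1, tokens.shape[-1])
    return prefill, decode


# Example: two 4-token prompts being prefilled plus three decode tokens.
q = torch.randn(2 * 4 + 3, 64)
prefill_q, decode_q = split_mixed_batch(q, num_prefill_tokens=8, num_prefills=2)
assert prefill_q.shape == (2, 4, 64)
assert decode_q.shape == (3, 1, 64)
```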
16 changes: 7 additions & 9 deletions tests/basic_correctness/test_chunked_prefill.py
@@ -6,32 +6,32 @@

Run `pytest tests/basic_correctness/test_chunked_prefill.py`.
"""
import os

Check failure (GitHub Actions / ruff): tests/basic_correctness/test_chunked_prefill.py:9:8: F401 `os` imported but unused
from contextlib import nullcontext

Check failure (GitHub Actions / ruff): tests/basic_correctness/test_chunked_prefill.py:10:24: F401 `contextlib.nullcontext` imported but unused

import pytest

from tests.kernels.utils import override_backend_env_variable

Check failure (GitHub Actions / ruff): tests/basic_correctness/test_chunked_prefill.py:14:33: F401 `tests.kernels.utils.override_backend_env_variable` imported but unused
from vllm.platforms import current_platform

Check failure (GitHub Actions / ruff): tests/basic_correctness/test_chunked_prefill.py:15:28: F401 `vllm.platforms.current_platform` imported but unused

from ..models.utils import check_logprobs_close, check_outputs_equal

Check failure (GitHub Actions / ruff): tests/basic_correctness/test_chunked_prefill.py:17:28: F401 `..models.utils.check_logprobs_close` imported but unused
from ..utils import multi_gpu_test

Check failure (GitHub Actions / ruff): tests/basic_correctness/test_chunked_prefill.py:18:21: F401 `..utils.multi_gpu_test` imported but unused

MODELS = [
"facebook/opt-125m",
"meta-llama/Llama-3.2-1B",
#"meta-llama/Llama-3.2-1B",
]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
@pytest.mark.parametrize("enforce_eager", [False, True])
@pytest.mark.parametrize("chunked_prefill_token_size", [4,])
@pytest.mark.parametrize("enforce_eager", [True])
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset distributed env properly. Only use a value > 1 when testing locally.
@pytest.mark.parametrize("tensor_parallel_size", [1])
@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
#@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
def test_models(
hf_runner,
vllm_runner,
@@ -42,14 +42,12 @@
chunked_prefill_token_size: int,
enforce_eager: bool,
tensor_parallel_size: int,
attention_backend: str,
monkeypatch,
) -> None:
"""
Checks exact match decode between huggingface model and vllm runner with
chunked prefill.
"""
override_backend_env_variable(monkeypatch, attention_backend)
#override_backend_env_variable(monkeypatch, attention_backend)

max_num_seqs = chunked_prefill_token_size
max_num_batched_tokens = chunked_prefill_token_size
@@ -76,7 +74,7 @@
)


@multi_gpu_test(num_gpus=2)
'''@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"])
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
@@ -320,4 +318,4 @@
chunk_size,
1,
dtype,
)
)'''
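With the attention-backend override and the multi-GPU variant disabled above, the suite now only exercises the eager, single-GPU path with a 4-token chunk. Roughly the same configuration can be reproduced outside pytest with a snippet like the one below; it assumes the habana_main fork keeps upstream vLLM's engine arguments, and the model, prompt, and chunk size are just the values used in the test:

```python
from vllm import LLM, SamplingParams

chunked_prefill_token_size = 4
llm = LLM(
    model="facebook/opt-125m",
    dtype="half",
    enforce_eager=True,
    enable_chunked_prefill=True,
    # Both limits are set to the chunk size so that prompts longer than
    # four tokens are prefilled across several scheduler steps.
    max_num_seqs=chunked_prefill_token_size,
    max_num_batched_tokens=chunked_prefill_token_size,
)
params = SamplingParams(temperature=0.0, max_tokens=32)
print(llm.generate(["Hello, my name is"], params)[0].outputs[0].text)
```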
2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -246,7 +246,7 @@
return x

if device is None:
device = "cpu" if current_platform.is_cpu() else "cuda"
device = "cpu" if current_platform.is_cpu() or current_platform.is_hpu() else "cuda"

Check failure (GitHub Actions / ruff): tests/conftest.py:249:81: E501 Line too long (96 > 80)

if isinstance(x, dict):
return {k: self.wrap_device(v, device) for k, v in x.items()}
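The added HPU branch is what trips the E501 check above; the same logic fits within the 80-column limit if the platform test is pulled out first. A suggestion only, reusing conftest's existing current_platform import, not part of this PR:

```python
if device is None:
    # Keep tensors on the host for CPU and HPU platforms, otherwise move
    # them to CUDA; behaviour is identical to the one-line version above.
    on_host = current_platform.is_cpu() or current_platform.is_hpu()
    device = "cpu" if on_host else "cuda"
```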
166 changes: 127 additions & 39 deletions vllm/attention/backends/hpu_attn.py
@@ -19,6 +19,8 @@
from vllm.logger import init_logger
from vllm.utils import is_fake_hpu

import habana_frameworks.torch as htorch

logger = init_logger(__name__)

HPUFusedSDPA = None
@@ -95,6 +97,8 @@
cross_block_scales: Optional[torch.Tensor] = None
cross_block_usage: Optional[torch.Tensor] = None
cross_attn_bias: Optional[torch.Tensor] = None
decode_slot_mapping: Optional[torch.Tensor] = None
decode_block_list: Optional[torch.Tensor] = None


class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
@@ -202,31 +206,99 @@
v_scale=v_scale,
)

batch_size, seq_len, hidden_size = query.shape
_, seq_len_kv, _ = key.shape

query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
block_indices = attn_metadata.block_indices
block_offsets = attn_metadata.block_offsets
if attn_metadata.is_prompt:
key = key.unflatten(0, (block_indices.size(0), -1))
value = value.unflatten(0, (block_indices.size(0), -1))
if kv_cache is not None:
key_cache, value_cache = HPUPagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)

# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
key_cache = self.k_cache(key, key_cache, block_indices,
block_offsets)
value_cache = self.v_cache(value, value_cache, block_indices,
block_offsets)

if attn_metadata.is_prompt:
hidden_size: int = 0
prefill_query = query[:attn_metadata.num_prefill_tokens].clone()
prefill_key = key[:attn_metadata.num_prefill_tokens].clone()
prefill_value = value[:attn_metadata.num_prefill_tokens].clone()
decode_query = query[attn_metadata.num_prefill_tokens:].clone()
decode_key = key[attn_metadata.num_prefill_tokens:].clone()
decode_value = value[attn_metadata.num_prefill_tokens:].clone()
htorch.core.mark_step()
if attn_metadata.num_decode_tokens > 0:
import pdb; pdb.set_trace()

Check failure (GitHub Actions / ruff): vllm/attention/backends/hpu_attn.py:218:23: E702 Multiple statements on one line (semicolon)
if attn_metadata.num_prefill_tokens > 0:
# prefill preprocessing
hidden_size = prefill_query.shape[-1]
# print(prefill_query.shape, hidden_size)
prefill_query = prefill_query.reshape(attn_metadata.num_prefills,
attn_metadata.num_prefill_tokens // attn_metadata.num_prefills,

Check failure (GitHub Actions / ruff): vllm/attention/backends/hpu_attn.py:224:81: E501 Line too long (95 > 80)
hidden_size)
hidden_size = prefill_key.shape[-1]
# print(prefill_key.shape, hidden_size)
prefill_key = prefill_key.reshape(attn_metadata.num_prefills,
attn_metadata.num_prefill_tokens // attn_metadata.num_prefills,
hidden_size)
hidden_size = prefill_value.shape[-1]
# print(prefill_value.shape, hidden_size)
prefill_value = prefill_value.reshape(attn_metadata.num_prefills,
attn_metadata.num_prefill_tokens // attn_metadata.num_prefills,
hidden_size)
prefill_batch_size, prefill_seq_len, prefill_hidden_size = prefill_query.shape
_, seq_len_kv, _ = prefill_key.shape
prefill_query = prefill_query.reshape(-1, self.num_heads, self.head_size)
prefill_key = prefill_key.reshape(-1, self.num_kv_heads, self.head_size)
prefill_value = prefill_value.reshape(-1, self.num_kv_heads, self.head_size)
block_indices = attn_metadata.block_indices
block_offsets = attn_metadata.block_offsets
prefill_key = prefill_key.unflatten(0, (block_indices.size(0), -1))
prefill_value = prefill_value.unflatten(0, (block_indices.size(0), -1))
if kv_cache is not None:
key_cache, value_cache = HPUPagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)

# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
prefill_key_cache = self.k_cache(prefill_key, key_cache, block_indices,
block_offsets)
prefill_value_cache = self.v_cache(prefill_value, value_cache, block_indices,
block_offsets)
htorch.core.mark_step()
if attn_metadata.num_decode_tokens > 0:
# decode preprocessing
import pdb; pdb.set_trace()
hidden_size = decode_query.shape[-1]
print(decode_query.shape, hidden_size)
decode_query = decode_query.reshape(attn_metadata.num_decode_tokens,
1, hidden_size)
hidden_size = decode_key.shape[-1]
print(decode_key.shape, hidden_size)
decode_key = decode_key.reshape(attn_metadata.num_decode_tokens,
1, hidden_size)
hidden_size = decode_value.shape[-1]
print(decode_value.shape, hidden_size)
decode_value = decode_value.reshape(attn_metadata.num_decode_tokens,
1, hidden_size)
decode_batch_size, decode_seq_len, decode_hidden_size = decode_query.shape
decode_query = decode_query.view(-1, self.num_heads, self.head_size)
decode_key = decode_key.view(-1, self.num_kv_heads, self.head_size)
decode_value = decode_value.view(-1, self.num_kv_heads, self.head_size)
block_indices = attn_metadata.block_indices
block_offsets = attn_metadata.block_offsets
if kv_cache is not None:
key_cache, value_cache = HPUPagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)

# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
decode_key_cache = self.k_cache(decode_key, key_cache, block_indices,
block_offsets)
decode_value_cache = self.v_cache(decode_value, value_cache, block_indices,
block_offsets)
htorch.core.mark_step()
import pdb; pdb.set_trace()


prompt_output: Optional[torch.Tensor] = None
decode_output: Optional[torch.Tensor] = None
batch_size: int = 0
seq_len: int = 0
if attn_metadata.num_prefills > 0:
# Prompt run.
batch_size = prefill_batch_size
seq_len = prefill_seq_len
hidden_size = prefill_hidden_size
query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
self.head_size)
@@ -247,9 +319,9 @@
attn_bias = None

out = ops.prompt_attention(
query.view(query_shape),
key.view(kv_shape),
value.view(kv_shape),
prefill_query.view(query_shape),
prefill_key.view(kv_shape),
prefill_value.view(kv_shape),
attn_bias=attn_bias,
p=0.0,
scale=self.scale,
@@ -262,11 +334,11 @@
else:
# TODO: enable FusedSDPA
out = HPUPagedAttention.forward_prefix(
query=query.view(query_shape),
key=key.view(kv_shape),
value=value.view(kv_shape),
key_cache=key_cache,
value_cache=value_cache,
query=prefill_query.view(query_shape),
key=prefill_key.view(kv_shape),
value=prefill_value.view(kv_shape),
key_cache=prefill_key_cache,
value_cache=prefill_value_cache,
block_list=attn_metadata.block_list,
attn_bias=attn_metadata.attn_bias,
scale=self.scale,
@@ -275,14 +347,22 @@
softmax_op=self.softmax,
keys_fetch_func=self.k_cache.fetch_from_cache,
values_fetch_func=self.v_cache.fetch_from_cache)
output = out.reshape(batch_size, seq_len, hidden_size)
else:
prompt_output = out.reshape(batch_size, seq_len, hidden_size)
htorch.core.mark_step()
if attn_metadata.num_decode_tokens > 0:
# Decoding run.
output = HPUPagedAttention.forward_decode(
import pdb; pdb.set_trace()
query = decode_query
key = decode_key
value = decode_value
batch_size = decode_batch_size
seq_len = decode_seq_len
hidden_size = decode_hidden_size
decode_output = HPUPagedAttention.forward_decode(
query=query,
key_cache=key_cache,
value_cache=value_cache,
block_list=attn_metadata.block_list,
key_cache=decode_key_cache,
value_cache=decode_value_cache,
block_list=attn_metadata.decode_block_list,
block_mapping=attn_metadata.block_mapping,
block_bias=attn_metadata.attn_bias,
block_scales=attn_metadata.block_scales,
Expand All @@ -294,8 +374,16 @@
block2batch_matmul_op=self.block2batch_matmul,
keys_fetch_func=self.k_cache.fetch_from_cache,
values_fetch_func=self.v_cache.fetch_from_cache)
htorch.core.mark_step()
# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)
if decode_output is None:
return prompt_output.view(batch_size * seq_len, hidden_size)
elif prompt_output is None:
return decode_output.view(batch_size * seq_len, hidden_size)
else:
prompt_output = prompt_output.view(batch_size * seq_len, hidden_size)
decode_output = decode_output.view(batch_size * seq_len, hidden_size)
return torch.cat((prompt_output, decode_output))

def forward_encoder_decoder(
self,
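Taken together, the new forward path caches both slices, runs ops.prompt_attention over the prefill slice and HPUPagedAttention.forward_decode over the decode slice, and concatenates the two outputs with prefill tokens first. The sketch below reproduces only that control flow, with plain scaled-dot-product attention standing in for both HPU kernels and for the paged KV cache, so the shapes and the final torch.cat can be checked in isolation; names and the head count are illustrative:

```python
import torch
import torch.nn.functional as F


def chunked_forward(query: torch.Tensor,
                    key: torch.Tensor,
                    value: torch.Tensor,
                    num_prefill_tokens: int,
                    num_prefills: int,
                    num_heads: int = 8) -> torch.Tensor:
    """Shape-level sketch of the split forward pass introduced in this PR."""
    head_size = query.shape[-1] // num_heads
    outputs = []

    if num_prefill_tokens > 0:
        # Prompt run: regroup as [num_prefills, seq_len, heads, head_size]
        # and apply causal attention (the PR calls ops.prompt_attention).
        seq_len = num_prefill_tokens // num_prefills
        shape = (num_prefills, seq_len, num_heads, head_size)
        q = query[:num_prefill_tokens].view(shape).transpose(1, 2)
        k = key[:num_prefill_tokens].view(shape).transpose(1, 2)
        v = value[:num_prefill_tokens].view(shape).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        outputs.append(out.transpose(1, 2).reshape(num_prefill_tokens, -1))

    if query.shape[0] > num_prefill_tokens:
        # Decode run: one query token per sequence. The PR attends over the
        # paged KV cache via HPUPagedAttention.forward_decode; here each
        # token only sees itself, which is enough to check the shapes.
        shape = (-1, 1, num_heads, head_size)
        q = query[num_prefill_tokens:].view(shape).transpose(1, 2)
        k = key[num_prefill_tokens:].view(shape).transpose(1, 2)
        v = value[num_prefill_tokens:].view(shape).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v)
        outputs.append(out.transpose(1, 2).reshape(out.shape[0], -1))

    # Prefill output first, then decode output, flattened back to
    # [total_tokens, hidden]: this mirrors the final torch.cat in the PR.
    return torch.cat(outputs, dim=0)


q = torch.randn(2 * 4 + 3, 64)  # two 4-token prompts plus three decode tokens
out = chunked_forward(q, q.clone(), q.clone(),
                      num_prefill_tokens=8, num_prefills=2)
assert out.shape == (11, 64)
```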
2 changes: 1 addition & 1 deletion vllm/engine/arg_utils.py
@@ -119,7 +119,7 @@ class EngineArgs:
enable_prefix_caching: Optional[bool] = None
disable_sliding_window: bool = False
use_v2_block_manager: bool = True
use_padding_aware_scheduling: bool = current_platform.is_hpu()
use_padding_aware_scheduling: bool = False
swap_space: float = 4 # GiB
cpu_offload_gb: float = 0 # GiB
gpu_memory_utilization: float = 0.90
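Because the default flips from current_platform.is_hpu() to False, padding-aware scheduling on HPU becomes opt-in rather than automatic. A minimal sketch of re-enabling it explicitly, assuming the field stays exposed through EngineArgs as it is in this diff:

```python
from vllm.engine.arg_utils import EngineArgs

# Opt back in to the behaviour that used to be the HPU default.
engine_args = EngineArgs(
    model="facebook/opt-125m",
    use_padding_aware_scheduling=True,
)
```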
3 changes: 1 addition & 2 deletions vllm/model_executor/sampling_metadata.py
@@ -291,8 +291,7 @@ def _prepare_seq_groups(
else:
# Decode
prompt_logprob_len = 0
query_len = query_lens[i] if query_lens is not None and len(
query_lens) > 0 else 1
query_len = 1
sample_len = len(seq_ids) * query_len if do_sample else 0

if sampling_params.seed is not None and generators is not None: