[WIP] Add HPU support to vLLM v1 #487 (Draft)

36 commits to be merged into base branch habana_main.

Changes shown from 1 commit; 36 commits in total.
2191184  vLLM v1 HPU prototype (kzawora-intel, Nov 12, 2024)
fd77180  copy gpu model runner code, add hpugraphs support and profile run (kzawora-intel, Nov 12, 2024)
4dadef5  i am very much struggling (kzawora-intel, Nov 13, 2024)
9db1409  it's hopeless (kzawora-intel, Nov 13, 2024)
3b3098c  [wip] bypass prefill chunking in v1 scheduler (kzawora-intel, Nov 14, 2024)
c24adb5  colonoscopy (kzawora-intel, Nov 14, 2024)
2da069e  prefill runs, decode has deadlock, idk why (kzawora-intel, Nov 14, 2024)
932ce93  i'm done for today (kzawora-intel, Nov 14, 2024)
fc6a1c2  do better job at prefill chunking detection (kzawora-intel, Nov 14, 2024)
ff0ed54  mixed batch scheduling is still a problem (kzawora-intel, Nov 15, 2024)
50aa6b3  Merge remote-tracking branch 'origin/habana_main' into HEAD (kzawora-intel, Nov 18, 2024)
debec16  general hpu code rewrite (kzawora-intel, Nov 18, 2024)
0c1d0b6  add debug stuff, it seems like prefill is functional (kzawora-intel, Nov 18, 2024)
35d3e38  slight code cleanup (kzawora-intel, Nov 18, 2024)
491f991  remove garbage changes (kzawora-intel, Nov 18, 2024)
e29b84a  gsm8k now produces 69% acc on llama3.1 (kzawora-intel, Nov 19, 2024)
27b4f32  add config not warmed up warnings (kzawora-intel, Nov 19, 2024)
087b5d2  add bucketinggit add -u .! (kzawora-intel, Nov 19, 2024)
6fdb6a9  llama3.1 now gives 81% in gsm8k without contiguous pa (kzawora-intel, Nov 19, 2024)
8714f9d  disable contiguous pa by default (kzawora-intel, Nov 19, 2024)
40ff0ac  async data copy (kzawora-intel, Nov 19, 2024)
28f2ac5  add split sampler optimization (kzawora-intel, Nov 19, 2024)
df7a1d4  add prompt batching (kzawora-intel, Nov 20, 2024)
623ed10  padded logits_indices and sampling + documentation (kzawora-intel, Nov 20, 2024)
0371c31  update docs (kzawora-intel, Nov 20, 2024)
d7b2a06  fix first-party random and greedy sampler for hpu (kzawora-intel, Nov 20, 2024)
c934e60  format.sh (kzawora-intel, Nov 20, 2024)
e0f4c26  add warmup w/ sampler (it doesn't work great tho) (kzawora-intel, Nov 21, 2024)
58c8f5d  add hpugraph check (kzawora-intel, Nov 21, 2024)
0c8b075  fix async engine, fix sampler corner cases (kzawora-intel, Nov 22, 2024)
fecedb5  Add padding-aware scheduling (kzawora-intel, Nov 25, 2024)
2ab1ac8  Merge remote-tracking branch 'origin/habana_main' into HEAD (kzawora-intel, Nov 25, 2024)
0d41073  bucketing refactor, enable contiguous pa, defrag blocks (kzawora-intel, Nov 26, 2024)
5645523  FreeKVCacheBlockHeapQueue bugfixes (kzawora-intel, Nov 26, 2024)
fd62723  [wip] add prefix caching support (it was actually really hard) (kzawora-intel, Nov 26, 2024)
e80f2be  fix hpugraphs and long seq corner case (kzawora-intel, Dec 4, 2024)
9 changes: 9 additions & 0 deletions vllm/attention/selector.py
@@ -24,6 +24,7 @@
    OPENVINO = enum.auto()
    FLASHINFER = enum.auto()
    HPU_ATTN = enum.auto()
    HPU_ATTN_V1 = enum.auto()
    PALLAS = enum.auto()
    IPEX = enum.auto()
    NO_ATTENTION = enum.auto()
@@ -172,6 +173,10 @@
        return FlashInferBackend
    elif backend == _Backend.HPU_ATTN:
        logger.info("Using HPUAttention backend.")
        from vllm.attention.backends.hpu_attn import HPUAttentionBackend
        return HPUAttentionBackend
    elif backend == _Backend.HPU_ATTN_V1:
        logger.info("Using HPUAttentionV1 backend.")
        from vllm.v1.attention.backends.hpu_attn import HPUAttentionBackendV1
        return HPUAttentionBackendV1
    elif backend == _Backend.PALLAS:
@@ -249,6 +254,10 @@
        return _Backend.ROCM_FLASH

    if current_platform.is_hpu():
        if selected_backend != _Backend.HPU_ATTN and selected_backend != _Backend.HPU_ATTN_V1:
            logger.info("Cannot use %s backend on HPU.", selected_backend)
        if use_v1:
            return _Backend.HPU_ATTN_V1
        return _Backend.HPU_ATTN

    if use_v1:

CI annotation — GitHub Actions / ruff (3.12): Ruff E501 at vllm/attention/selector.py:257:81, line too long (94 > 80).
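For readers skimming the diff, here is a condensed, standalone sketch of the HPU branch of the backend-selection logic added above. The enum values mirror the diff; the function name and the print-based logging are illustrative simplifications, not vLLM APIs.

import enum


class _Backend(enum.Enum):
    HPU_ATTN = enum.auto()
    HPU_ATTN_V1 = enum.auto()


def pick_hpu_backend(selected_backend: "_Backend | None", use_v1: bool) -> _Backend:
    # Mirrors the branch added in this diff: any explicitly requested backend
    # other than the two HPU ones is ignored on HPU, and the V1 engine gets
    # the new HPU_ATTN_V1 backend while V0 keeps HPU_ATTN.
    if selected_backend not in (_Backend.HPU_ATTN, _Backend.HPU_ATTN_V1):
        print(f"Cannot use {selected_backend} backend on HPU.")
    return _Backend.HPU_ATTN_V1 if use_v1 else _Backend.HPU_ATTN


assert pick_hpu_backend(None, use_v1=True) is _Backend.HPU_ATTN_V1
assert pick_hpu_backend(None, use_v1=False) is _Backend.HPU_ATTN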
289 changes: 289 additions & 0 deletions vllm/v1/attention/backends/hpu_attn.py
@@ -0,0 +1,289 @@
###############################################################################
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################

import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
import vllm_hpu_extension.ops as ops
from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
                                               HPUPagedAttentionMetadata)
from vllm.logger import init_logger
from vllm.utils import is_fake_hpu

logger = init_logger(__name__)


class HPUAttentionBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "hpu-attn"

    @staticmethod
    def get_impl_cls() -> Type["HPUAttentionImpl"]:
        return HPUAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return HPUAttentionMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return HPUPagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                    num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dsts: torch.Tensor,
    ) -> None:
        HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dsts)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dsts: torch.Tensor,
    ) -> None:
        HPUPagedAttention.copy_blocks(kv_caches, src_to_dsts)


@dataclass
class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
    """Metadata for HPUAttentionBackend."""
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    attn_bias: Optional[torch.Tensor]
    seq_lens_tensor: Optional[torch.Tensor]
    context_lens_tensor: Optional[torch.Tensor]


class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        max_seq_len: int = 4096,
    ) -> None:
        super(AttentionImpl, self).__init__()
        self.kv_cache_dtype = kv_cache_dtype
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.matmul_qk = Matmul()
        self.softmax = Softmax()
        self.matmul_av = Matmul()
        self.k_cache = VLLMKVCache()
        self.v_cache = VLLMKVCache()
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        self.alibi_slopes = alibi_slopes
        if alibi_slopes is not None:
            alibi_slopes_tensor = torch.tensor(alibi_slopes,
                                               dtype=torch.bfloat16)
            self.alibi_slopes = alibi_slopes_tensor
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
                                               '1').lower() in ['1', 'true'] \
                                               and not is_fake_hpu()
        if self.prefill_use_fusedsdpa:
            assert alibi_slopes is None, \
                'Prefill with FusedSDPA not supported with alibi slopes!'

        supported_head_sizes = HPUPagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: HPUAttentionMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with HPU attention ops and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "HPUAttentionImpl")
        batch_size, seq_len, hidden_size = query.shape
        _, seq_len_kv, _ = key.shape

        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)
        block_indices = attn_metadata.block_indices
        block_offsets = attn_metadata.block_offsets
        if attn_metadata.is_prompt:
            key = key.unflatten(0, (block_indices.size(0), -1))
            value = value.unflatten(0, (block_indices.size(0), -1))
        if kv_cache is not None:
            key_cache, value_cache = HPUPagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            key_cache = self.k_cache(key, key_cache, block_indices,
                                     block_offsets)
            value_cache = self.v_cache(value, value_cache, block_indices,
                                       block_offsets)

        if attn_metadata.is_prompt:
            # Prompt run.
            query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
            kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
                        self.head_size)
            if attn_metadata is None or attn_metadata.block_list is None:
                if not self.prefill_use_fusedsdpa:
                    # TODO: move this outside of model
                    assert attn_metadata.attn_bias is not None, \
                        'attn_bias must be set before calling model.forward'
                    attn_bias = attn_metadata.attn_bias
                    if self.alibi_slopes is not None:
                        position_bias = _make_alibi_bias(
                            self.alibi_slopes, self.num_kv_heads,
                            attn_bias.dtype, attn_bias.shape[-1])
                        attn_bias = attn_bias.tile(
                            (1, self.num_kv_heads, 1, 1))
                        attn_bias.add_(position_bias)
                else:
                    attn_bias = None

                out = ops.prompt_attention(
                    query.view(query_shape),
                    key.view(kv_shape),
                    value.view(kv_shape),
                    attn_bias=attn_bias,
                    p=0.0,
                    scale=self.scale,
                    matmul_qk_op=self.matmul_qk,
                    softmax_op=self.softmax,
                    matmul_av_op=self.matmul_av,
                )
            else:
                # TODO: enable FusedSDPA
                out = HPUPagedAttention.forward_prefix(
                    query=query.view(query_shape),
                    key=key.view(kv_shape),
                    value=value.view(kv_shape),
                    key_cache=key_cache,
                    value_cache=value_cache,
                    block_list=attn_metadata.block_list,
                    attn_bias=attn_metadata.attn_bias,
                    scale=self.scale,
                    matmul_qk_op=self.matmul_qk,
                    matmul_av_op=self.matmul_av,
                    softmax_op=self.softmax,
                    keys_fetch_func=self.k_cache.fetch_from_cache,
                    values_fetch_func=self.v_cache.fetch_from_cache)
            output = out.reshape(batch_size, seq_len, hidden_size)
        else:
            # Decoding run.
            output = HPUPagedAttention.forward_decode(
                query=query,
                key_cache=key_cache,
                value_cache=value_cache,
                block_list=attn_metadata.block_list,
                block_mapping=attn_metadata.block_mapping,
                block_bias=attn_metadata.attn_bias,
                block_scales=attn_metadata.block_scales,
                block_groups=attn_metadata.block_groups,
                scale=self.scale,
                matmul_qk_op=self.matmul_qk,
                matmul_av_op=self.matmul_av,
                keys_fetch_func=self.k_cache.fetch_from_cache,
                values_fetch_func=self.v_cache.fetch_from_cache)
        # Reshape the output tensor.
        return output.view(batch_size, seq_len, hidden_size)


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    dtype: torch.dtype,
    seq_len: int,
) -> torch.Tensor:
    bias = torch.arange(seq_len, dtype=dtype)
    # NOTE(zhuohan): HF uses
    #     `bias = bias[None, :].repeat(seq_len, 1)`
    # here. We find that both biases give the same results, but
    # the bias below more accurately follows the original ALiBi
    # paper.
    # Calculate a matrix where each element represents ith element - jth
    # element.
    bias = bias[None, :] - bias[:, None]

    padded_len = (seq_len + 7) // 8 * 8
    num_heads = alibi_slopes.shape[0]
    bias = torch.empty(
        1,  # batch size
        num_heads,
        seq_len,
        padded_len,
        device=alibi_slopes.device,
        dtype=dtype,
    )[:, :, :, :seq_len].copy_(bias)
    bias.mul_(alibi_slopes[:, None, None])
    if num_heads != num_kv_heads:
        bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
    return bias
8 changes: 4 additions & 4 deletions vllm/v1/engine/core.py
@@ -17,7 +17,7 @@
 from vllm.v1.core.scheduler import Scheduler
 from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
                             EngineCoreRequest, EngineCoreRequestType)
-from vllm.v1.executor.gpu_executor import GPUExecutor
+#from vllm.v1.executor.gpu_executor import GPUExecutor
 from vllm.v1.request import Request, RequestStatus
 from vllm.version import __version__ as VLLM_VERSION

@@ -34,7 +34,7 @@ class EngineCore:
     def __init__(
         self,
         vllm_config: VllmConfig,
-        executor_class: Type[GPUExecutor],
+        executor_class: Type[Any],
         usage_context: UsageContext,
     ):
         # Override the configs for V1.
@@ -124,7 +124,7 @@ class EngineCoreProc(EngineCore):
     def __init__(
         self,
         vllm_config: VllmConfig,
-        executor_class: Type[GPUExecutor],
+        executor_class: Type[Any],
         usage_context: UsageContext,
         input_path: str,
         output_path: str,
@@ -209,7 +209,7 @@ def wait_for_startup(
     @staticmethod
     def make_engine_core_process(
         vllm_config: VllmConfig,
-        executor_class: Type[GPUExecutor],
+        executor_class: Type[Any],
         usage_context: UsageContext,
         input_path: str,
         output_path: str,
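The executor_class annotations above are loosened from Type[GPUExecutor] to Type[Any] so that a non-GPU executor (such as the HPUExecutor added in this PR) can be passed into the V1 engine core. A tighter alternative could be a structural Protocol; the sketch below is purely hypothetical, and its method names are illustrative assumptions rather than the actual vLLM v1 executor interface.

from typing import Any, Protocol


class ExecutorLike(Protocol):
    """Hypothetical structural interface for engine executors.

    The method names here are assumptions for illustration only.
    """

    def determine_num_available_blocks(self) -> tuple[int, int]:
        ...

    def execute_model(self, scheduler_output: Any) -> Any:
        ...

    def shutdown(self) -> None:
        ...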
10 changes: 7 additions & 3 deletions vllm/v1/engine/llm_engine.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Mapping, Optional, Type, Union
+from typing import Any, Dict, List, Mapping, Optional, Type, Union

 from vllm.config import VllmConfig
 from vllm.engine.arg_utils import EngineArgs
@@ -8,6 +8,7 @@
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
+from vllm.platforms import current_platform
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
@@ -16,7 +17,6 @@
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.detokenizer import Detokenizer
 from vllm.v1.engine.processor import Processor
-from vllm.v1.executor.gpu_executor import GPUExecutor

 logger = init_logger(__name__)

@@ -27,7 +27,7 @@ class LLMEngine:
     def __init__(
         self,
         vllm_config: VllmConfig,
-        executor_class: Type[GPUExecutor],
+        executor_class: Type[Any],
         log_stats: bool,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
         stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
@@ -92,6 +92,10 @@ def from_engine_args(

     @classmethod
     def _get_executor_cls(cls, vllm_config: VllmConfig):
+        if current_platform.is_hpu():
+            from vllm.v1.executor.hpu_executor import HPUExecutor
+            return HPUExecutor
+        from vllm.v1.executor.gpu_executor import GPUExecutor
         return GPUExecutor

     def stop_remote_worker_execution_loop(self) -> None:
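With the platform dispatch in _get_executor_cls above, the V1 engine can be constructed the same way on Gaudi as on GPU. A rough usage sketch follows; it is untested here, the VLLM_USE_V1 switch is an assumption about how the V1 path is enabled, and the model name is only an example.

import os

os.environ["VLLM_USE_V1"] = "1"  # assumed switch for the V1 engine path

from vllm.engine.arg_utils import EngineArgs
from vllm.v1.engine.llm_engine import LLMEngine

# On an HPU host, _get_executor_cls() above resolves to HPUExecutor;
# on CUDA machines it still returns GPUExecutor.
args = EngineArgs(model="meta-llama/Llama-3.1-8B-Instruct")  # example model
engine = LLMEngine.from_engine_args(args)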