attention: add AttentionState abstraction (#863)
AlpinDale authored Dec 4, 2024
1 parent 82eabb6 commit 1405051
Showing 17 changed files with 368 additions and 247 deletions.
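This commit introduces an `AttentionState` interface so each attention backend can own the objects it must keep alive across model-runner iterations and CUDA graph capture, instead of the runner hard-coding backend-specific buffers. The sketch below shows one way a runner could drive the new interface; it is illustrative only, and names such as `SketchModelRunner`, `capture_model`, and `execute_model` are hypothetical, since the runner-side changes are not expanded in this view.

```python
# Hypothetical sketch of a runner driving the new AttentionState interface.
# The runner/method names here are assumptions for illustration; only the
# AttentionBackend and AttentionState methods come from this commit.
from aphrodite.attention import AttentionBackend


class SketchModelRunner:
    def __init__(self, attn_backend: AttentionBackend):
        self.attn_backend = attn_backend
        # Each backend now advertises a state class: CommonAttentionState for
        # most backends, FlashInferState for FlashInfer.
        self.attn_state = attn_backend.get_state_cls()(self)

    def capture_model(self, max_batch_size: int) -> None:
        # Backend-specific capture buffers live for the duration of this block.
        with self.attn_state.graph_capture(max_batch_size):
            for batch_size in range(max_batch_size, 0, -1):
                attn_metadata = (
                    self.attn_state.graph_capture_get_metadata_for_batch(
                        batch_size))
                # ... capture a CUDA graph using attn_metadata, then store
                # self.attn_state.graph_clone(batch_size) alongside it ...

    def execute_model(self, model_input) -> None:
        # Let the backend finalize its wrappers/buffers before the forward pass.
        self.attn_state.begin_forward(model_input)
        # ... run the model ...
```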
2 changes: 1 addition & 1 deletion Dockerfile
@@ -118,7 +118,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/aphrodite-workspace
python3 -m pip install dist/*.whl --verbose

RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.3/flashinfer-0.1.3+cu124torch2.4-cp310-cp310-linux_x86_64.whl
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.4/flashinfer-0.1.4+cu121torch2.4-cp310-cp310-linux_x86_64.whl
#################### Aphrodite installation IMAGE ####################


2 changes: 2 additions & 0 deletions aphrodite/attention/__init__.py
@@ -1,6 +1,7 @@
from aphrodite.attention.backends.abstract import (AttentionBackend,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionState,
AttentionType)
from aphrodite.attention.layer import Attention
from aphrodite.attention.selector import get_attn_backend
@@ -11,5 +12,6 @@
"AttentionType",
"AttentionMetadataBuilder",
"Attention",
"AttentionState",
"get_attn_backend",
]
49 changes: 48 additions & 1 deletion aphrodite/attention/backends/abstract.py
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, fields
from enum import Enum, auto
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
@@ -8,7 +9,7 @@

if TYPE_CHECKING:
from aphrodite.task_handler.model_runner_base import (
ModelRunnerInputBuilderBase)
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase)


class AttentionType(Enum):
@@ -35,6 +36,10 @@ def get_impl_cls() -> Type["AttentionImpl"]:
def get_metadata_cls() -> Type["AttentionMetadata"]:
raise NotImplementedError

@staticmethod
def get_state_cls() -> Type["AttentionState"]:
raise NotImplementedError

@classmethod
def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
return cls.get_metadata_cls()(*args, **kwargs)
@@ -127,6 +132,48 @@ def asdict_zerocopy(self,
T = TypeVar("T", bound=AttentionMetadata)


class AttentionState(ABC, Generic[T]):
"""Holds attention backend specific objects reused during the
lifetime of the model runner.
"""

@abstractmethod
def __init__(self, runner: "ModelRunnerBase"):
...

@abstractmethod
@contextmanager
def graph_capture(self, max_batch_size: int):
"""Context manager used when capturing a CUDA graph."""
yield

@abstractmethod
def graph_clone(self, batch_size: int) -> "AttentionState[T]":
"""Clone attention state to save in CUDA graph metadata."""
...

@abstractmethod
def graph_capture_get_metadata_for_batch(self, batch_size: int) -> T:
"""Get attention metadata for CUDA graph capture of batch_size."""
...

@abstractmethod
def get_graph_input_buffers(self, attn_metadata: T) -> Dict[str, Any]:
"""Get attention-specific input buffers for CUDA graph capture."""
...

@abstractmethod
def prepare_graph_input_buffers(self, input_buffers: Dict[str, Any],
attn_metadata: T) -> None:
"""In-place modify input buffers dict for CUDA graph replay."""
...

@abstractmethod
def begin_forward(self, model_input: "ModelRunnerInputBase") -> None:
"""Prepare state for forward pass."""
...


class AttentionMetadataBuilder(ABC, Generic[T]):
"""Abstract class for attention metadata builders."""

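For reference, a minimal do-nothing implementation of the contract above might look like the following sketch. It is not part of this diff; the shared `CommonAttentionState` that most backends return from `get_state_cls()` lives in `aphrodite/attention/backends/utils.py`, which is not expanded here.

```python
from contextlib import contextmanager
from typing import Any, Dict

from aphrodite.attention import AttentionState


# Illustrative no-op state satisfying the abstract interface; a real backend
# (CommonAttentionState, FlashInferState) stores workspace buffers, wrappers,
# or CUDA-graph capture tensors here instead.
class NoOpAttentionState(AttentionState):
    def __init__(self, runner):
        self.runner = runner

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        yield  # nothing to allocate or tear down

    def graph_clone(self, batch_size: int) -> "NoOpAttentionState":
        return self.__class__(self.runner)

    def graph_capture_get_metadata_for_batch(self, batch_size: int):
        raise NotImplementedError("CUDA graphs are not supported in this sketch")

    def get_graph_input_buffers(self, attn_metadata) -> Dict[str, Any]:
        return {"slot_mapping": attn_metadata.slot_mapping}

    def prepare_graph_input_buffers(self, input_buffers: Dict[str, Any],
                                    attn_metadata) -> None:
        return

    def begin_forward(self, model_input) -> None:
        return
```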
7 changes: 6 additions & 1 deletion aphrodite/attention/backends/blocksparse_attn.py
@@ -7,7 +7,8 @@
AttentionImpl,
AttentionMetadata,
AttentionType)
from aphrodite.attention.backends.utils import CommonMetadataBuilder
from aphrodite.attention.backends.utils import (CommonAttentionState,
CommonMetadataBuilder)
from aphrodite.attention.ops.blocksparse_attention.interface import (
LocalStridedBlockSparseAttn, get_head_sliding_step)
from aphrodite.attention.ops.paged_attn import PagedAttention
@@ -100,6 +101,10 @@ def get_metadata_cls() -> Type["AttentionMetadata"]:
def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]:
return BlocksparseFlashAttentionMetadataBuilder

@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState

@staticmethod
def get_kv_cache_shape(
num_blocks: int,
5 changes: 5 additions & 0 deletions aphrodite/attention/backends/flash_attn.py
@@ -11,6 +11,7 @@
AttentionMetadataBuilder,
AttentionType)
from aphrodite.attention.backends.utils import (PAD_SLOT_ID,
CommonAttentionState,
compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
@@ -145,6 +146,10 @@ def get_metadata_cls() -> Type["AttentionMetadata"]:
def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder

@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState

@staticmethod
def get_kv_cache_shape(
num_blocks: int,
160 changes: 160 additions & 0 deletions aphrodite/attention/backends/flashinfer.py
@@ -1,14 +1,19 @@
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type

try:
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper

import aphrodite.attention.backends.flash_attn # noqa
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
except ImportError:
BatchDecodeWithPagedKVCacheWrapper = None
CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
BatchPrefillWithPagedKVCacheWrapper = None
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0

import torch

@@ -17,6 +22,7 @@
AttentionImpl,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionState,
AttentionType)
from aphrodite.attention.backends.utils import (PAD_SLOT_ID,
compute_slot_mapping,
@@ -48,6 +54,10 @@ def get_metadata_cls() -> Type["AttentionMetadata"]:
def get_builder_cls() -> Type["FlashInferMetadataBuilder"]:
return FlashInferMetadataBuilder

@staticmethod
def get_state_cls() -> Type["FlashInferState"]:
return FlashInferState

@staticmethod
def get_kv_cache_shape(
num_blocks: int,
@@ -77,6 +87,156 @@ def get_supported_head_sizes() -> List[int]:
return [64, 128, 256]


class FlashInferState(AttentionState):
def __init__(self, runner):
self.runner = runner
self._is_graph_capturing = False
self._workspace_buffer = None
self._decode_wrapper = None
self._prefill_wrapper = None

def _get_workspace_buffer(self):
if self._workspace_buffer is None:
self._workspace_buffer = torch.empty(
FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=self.runner.device)
return self._workspace_buffer

def _get_prefill_wrapper(self):
if self._prefill_wrapper is None:
self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
self._get_workspace_buffer(), "NHD")
return self._prefill_wrapper

def _get_decode_wrapper(self):
if self._decode_wrapper is None:
num_qo_heads = (self.runner.model_config.get_num_attention_heads(
self.runner.parallel_config))
num_kv_heads = self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config)
use_tensor_cores = num_qo_heads // num_kv_heads >= 4
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self._get_workspace_buffer(),
"NHD",
use_tensor_cores=use_tensor_cores)
return self._decode_wrapper

@contextmanager
def graph_capture(self, max_batch_size: int):
self._is_graph_capturing = True
self._graph_decode_wrapper = None
self._graph_slot_mapping = torch.full((max_batch_size, ),
PAD_SLOT_ID,
dtype=torch.long,
device=self.runner.device)
self._graph_seq_lens = torch.ones(max_batch_size,
dtype=torch.int32,
device=self.runner.device)
self._graph_block_tables = torch.from_numpy(
self.runner.graph_block_tables).to(device=self.runner.device)
self._graph_decode_workspace_buffer = self._get_workspace_buffer()
self._graph_indices_buffer = torch.empty(
max_batch_size * self.runner.cache_config.num_gpu_blocks,
dtype=torch.int32,
device=self.runner.device)
self._graph_indptr_buffer = torch.empty(max_batch_size + 1,
dtype=torch.int32,
device=self.runner.device)
self._graph_last_page_len_buffer = torch.empty(
max_batch_size, dtype=torch.int32, device=self.runner.device)
yield
self._is_graph_capturing = False
del self._graph_slot_mapping
del self._graph_seq_lens
del self._graph_block_tables
del self._graph_decode_workspace_buffer
del self._graph_indices_buffer
del self._graph_indptr_buffer
del self._graph_last_page_len_buffer
del self._graph_decode_wrapper

def graph_clone(self, batch_size: int):
assert self._is_graph_capturing
state = self.__class__(self.runner)
state._workspace_buffer = self._graph_decode_workspace_buffer
state._decode_wrapper = self._graph_decode_wrapper
state._prefill_wrapper = self._get_prefill_wrapper()
return state

def graph_capture_get_metadata_for_batch(self, batch_size: int):
assert self._is_graph_capturing
_indptr_buffer = self._graph_indptr_buffer[:batch_size + 1]
_last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size]
num_qo_heads = (self.runner.model_config.get_num_attention_heads(
self.runner.parallel_config))
num_kv_heads = self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config)
use_tensor_cores = num_qo_heads // num_kv_heads >= 4
self._graph_decode_wrapper = \
CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
self._graph_decode_workspace_buffer, _indptr_buffer,
self._graph_indices_buffer, _last_page_len_buffer, "NHD",
use_tensor_cores)
kv_cache_dtype = get_kv_cache_torch_dtype(
self.runner.kv_cache_dtype, self.runner.model_config.dtype)
paged_kv_indptr_tensor_host = torch.arange(0,
batch_size + 1,
dtype=torch.int32)
paged_kv_indices_tensor_host = torch.arange(0,
batch_size,
dtype=torch.int32)
paged_kv_last_page_len_tensor_host = torch.full((batch_size, ),
self.runner.block_size,
dtype=torch.int32)
query_start_loc_host = torch.arange(0,
batch_size + 1,
dtype=torch.int32)
attn_metadata = self.runner.attn_backend.make_metadata(
num_prefills=0,
slot_mapping=self._graph_slot_mapping[:batch_size],
num_prefill_tokens=0,
num_decode_tokens=batch_size,
max_prefill_seq_len=0,
block_tables=self._graph_block_tables,
paged_kv_indptr=paged_kv_indptr_tensor_host,
paged_kv_indices=paged_kv_indices_tensor_host,
paged_kv_last_page_len=paged_kv_last_page_len_tensor_host,
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim=self.runner.model_config.get_head_size(),
page_size=self.runner.block_size,
seq_start_loc=None,
query_start_loc=query_start_loc_host,
device=self.runner.device,
data_type=kv_cache_dtype,
use_cuda_graph=True,
decode_wrapper=self._graph_decode_wrapper,
prefill_wrapper=None)
attn_metadata.begin_forward()
return attn_metadata

def get_graph_input_buffers(self, attn_metadata):
return {
"slot_mapping": attn_metadata.slot_mapping,
}

def prepare_graph_input_buffers(self, input_buffers, attn_metadata):
return

def begin_forward(self, model_input):
assert not self._is_graph_capturing
state = self
if model_input.attn_metadata.use_cuda_graph:
batch_size = model_input.input_tokens.shape[0]
state = (self.runner.graph_runners[model_input.virtual_engine]
[batch_size].attn_state)
model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper(
)
model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper()
model_input.attn_metadata.begin_forward()


@dataclass
class FlashInferMetadata(AttentionMetadata):
# Maximum sequence length among prefill batch. 0 if there are decoding
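The capture/replay flow implied by `FlashInferState` above, spelled out as a sketch: during capture, `graph_capture` allocates the shared buffers, `graph_capture_get_metadata_for_batch` builds a `CUDAGraphBatchDecodeWithPagedKVCacheWrapper` bound to per-batch-size slices of them, and `graph_clone` snapshots a state to store with each captured graph; at replay time, `begin_forward` looks up that stored state and re-attaches its prefill/decode wrappers to the live metadata. The `capture_graph` helper and single-virtual-engine indexing below are assumptions for illustration, not code from this commit.

```python
from aphrodite.attention.backends.flashinfer import FlashInferState


def sketch_capture_and_replay(runner, model, model_input, capture_graph):
    """Hypothetical capture/replay flow for FlashInferState.

    `capture_graph` and the graph_runners indexing are illustrative
    assumptions, not code from this commit.
    """
    state = FlashInferState(runner)

    with state.graph_capture(max_batch_size=8):
        for bs in (8, 4, 2, 1):
            # Builds a CUDAGraphBatchDecodeWithPagedKVCacheWrapper over slices
            # of the buffers allocated by graph_capture(), then calls
            # begin_forward() on the resulting metadata.
            meta = state.graph_capture_get_metadata_for_batch(bs)
            graph_runner = capture_graph(model, meta)        # assumed helper
            graph_runner.attn_state = state.graph_clone(bs)  # snapshot for replay
            runner.graph_runners[0][bs] = graph_runner

    # At decode time with use_cuda_graph=True, begin_forward() finds the
    # captured state for this batch size and points the live attn_metadata at
    # its prefill/decode wrappers before calling attn_metadata.begin_forward().
    state.begin_forward(model_input)
```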
7 changes: 6 additions & 1 deletion aphrodite/attention/backends/ipex_attn.py
@@ -10,7 +10,8 @@
AttentionImpl,
AttentionMetadata,
AttentionType)
from aphrodite.attention.backends.utils import CommonMetadataBuilder
from aphrodite.attention.backends.utils import (CommonAttentionState,
CommonMetadataBuilder)
from aphrodite.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)

@@ -35,6 +36,10 @@ def get_metadata_cls() -> Type["IpexAttnMetadata"]:
def get_builder_cls() -> Type["IpexAttnMetadataBuilder"]:
return IpexAttnMetadataBuilder

@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState

@staticmethod
def get_kv_cache_shape(
num_blocks: int,
7 changes: 6 additions & 1 deletion aphrodite/attention/backends/openvino.py
@@ -1,11 +1,12 @@
from dataclasses import dataclass
from typing import List, Tuple
from typing import List, Tuple, Type

import openvino as ov
import torch

from aphrodite.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
from aphrodite.attention.backends.utils import CommonAttentionState


class OpenVINOAttentionBackend(AttentionBackend):
@@ -24,6 +25,10 @@ def get_impl_cls():
def make_metadata(*args, **kwargs) -> "AttentionMetadata":
raise NotImplementedError

@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState

@staticmethod
def make_openvino_metadata(*args, **kwargs) -> "OpenVINOAttentionMetadata":
return OpenVINOAttentionMetadata(*args, **kwargs)
5 changes: 5 additions & 0 deletions aphrodite/attention/backends/pallas.py
@@ -8,6 +8,7 @@
AttentionImpl,
AttentionMetadata,
AttentionType)
from aphrodite.attention.backends.utils import CommonAttentionState


class PallasAttentionBackend(AttentionBackend):
@@ -20,6 +21,10 @@ def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
def get_metadata_cls() -> Type["PallasMetadata"]:
return PallasMetadata

@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState

@staticmethod
def get_kv_cache_shape(
num_blocks: int,