diff --git a/Dockerfile b/Dockerfile
index 664b14d3d..edf17ff5b 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -118,7 +118,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/aphrodite-workspace
     python3 -m pip install dist/*.whl --verbose
 
 RUN --mount=type=cache,target=/root/.cache/pip \
-    python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.3/flashinfer-0.1.3+cu124torch2.4-cp310-cp310-linux_x86_64.whl
+    python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.4/flashinfer-0.1.4+cu121torch2.4-cp310-cp310-linux_x86_64.whl
 
 #################### Aphrodite installation IMAGE ####################
 
diff --git a/aphrodite/attention/__init__.py b/aphrodite/attention/__init__.py
index c00a1fe5b..184f4e78d 100644
--- a/aphrodite/attention/__init__.py
+++ b/aphrodite/attention/__init__.py
@@ -1,6 +1,7 @@
 from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionMetadata,
                                                    AttentionMetadataBuilder,
+                                                   AttentionState,
                                                    AttentionType)
 from aphrodite.attention.layer import Attention
 from aphrodite.attention.selector import get_attn_backend
@@ -11,5 +12,6 @@
     "AttentionType",
     "AttentionMetadataBuilder",
     "Attention",
+    "AttentionState",
     "get_attn_backend",
 ]
diff --git a/aphrodite/attention/backends/abstract.py b/aphrodite/attention/backends/abstract.py
index af02f5e7f..023eb7509 100644
--- a/aphrodite/attention/backends/abstract.py
+++ b/aphrodite/attention/backends/abstract.py
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from contextlib import contextmanager
 from dataclasses import dataclass, fields
 from enum import Enum, auto
 from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
@@ -8,7 +9,7 @@
 
 if TYPE_CHECKING:
     from aphrodite.task_handler.model_runner_base import (
-        ModelRunnerInputBuilderBase)
+        ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase)
 
 
 class AttentionType(Enum):
@@ -35,6 +36,10 @@ def get_impl_cls() -> Type["AttentionImpl"]:
     def get_metadata_cls() -> Type["AttentionMetadata"]:
         raise NotImplementedError
 
+    @staticmethod
+    def get_state_cls() -> Type["AttentionState"]:
+        raise NotImplementedError
+
     @classmethod
     def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
         return cls.get_metadata_cls()(*args, **kwargs)
@@ -127,6 +132,48 @@ def asdict_zerocopy(self,
 T = TypeVar("T", bound=AttentionMetadata)
 
 
+class AttentionState(ABC, Generic[T]):
+    """Holds attention backend specific objects reused during the
+    lifetime of the model runner.
+    """
+
+    @abstractmethod
+    def __init__(self, runner: "ModelRunnerBase"):
+        ...
+
+    @abstractmethod
+    @contextmanager
+    def graph_capture(self, max_batch_size: int):
+        """Context manager used when capturing a CUDA graph."""
+        yield
+
+    @abstractmethod
+    def graph_clone(self, batch_size: int) -> "AttentionState[T]":
+        """Clone attention state to save in CUDA graph metadata."""
+        ...
+
+    @abstractmethod
+    def graph_capture_get_metadata_for_batch(self, batch_size: int) -> T:
+        """Get attention metadata for CUDA graph capture of batch_size."""
+        ...
+
+    @abstractmethod
+    def get_graph_input_buffers(self, attn_metadata: T) -> Dict[str, Any]:
+        """Get attention-specific input buffers for CUDA graph capture."""
+        ...
+
+    @abstractmethod
+    def prepare_graph_input_buffers(self, input_buffers: Dict[str, Any],
+                                    attn_metadata: T) -> None:
+        """In-place modify input buffers dict for CUDA graph replay."""
+        ...
+
+    @abstractmethod
+    def begin_forward(self, model_input: "ModelRunnerInputBase") -> None:
+        """Prepare state for forward pass."""
+        ...
+
+
 class AttentionMetadataBuilder(ABC, Generic[T]):
     """Abstract class for attention metadata builders."""
 
diff --git a/aphrodite/attention/backends/blocksparse_attn.py b/aphrodite/attention/backends/blocksparse_attn.py
index 18816adae..e952489a2 100644
--- a/aphrodite/attention/backends/blocksparse_attn.py
+++ b/aphrodite/attention/backends/blocksparse_attn.py
@@ -7,7 +7,8 @@
                                                    AttentionImpl,
                                                    AttentionMetadata,
                                                    AttentionType)
-from aphrodite.attention.backends.utils import CommonMetadataBuilder
+from aphrodite.attention.backends.utils import (CommonAttentionState,
+                                                CommonMetadataBuilder)
 from aphrodite.attention.ops.blocksparse_attention.interface import (
     LocalStridedBlockSparseAttn, get_head_sliding_step)
 from aphrodite.attention.ops.paged_attn import PagedAttention
@@ -100,6 +101,10 @@ def get_metadata_cls() -> Type["AttentionMetadata"]:
     def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]:
         return BlocksparseFlashAttentionMetadataBuilder
 
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
diff --git a/aphrodite/attention/backends/flash_attn.py b/aphrodite/attention/backends/flash_attn.py
index e41c01a75..788d88b1a 100644
--- a/aphrodite/attention/backends/flash_attn.py
+++ b/aphrodite/attention/backends/flash_attn.py
@@ -11,6 +11,7 @@
                                                     AttentionMetadataBuilder,
                                                     AttentionType)
 from aphrodite.attention.backends.utils import (PAD_SLOT_ID,
+                                                CommonAttentionState,
                                                 compute_slot_mapping,
                                                 compute_slot_mapping_start_idx,
                                                 is_block_tables_empty)
@@ -145,6 +146,10 @@ def get_metadata_cls() -> Type["AttentionMetadata"]:
     def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
         return FlashAttentionMetadataBuilder
 
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
diff --git a/aphrodite/attention/backends/flashinfer.py b/aphrodite/attention/backends/flashinfer.py
index c3d9acf3a..989f05820 100644
--- a/aphrodite/attention/backends/flashinfer.py
+++ b/aphrodite/attention/backends/flashinfer.py
@@ -1,14 +1,19 @@
+from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type
 
 try:
     from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
     from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
 
     import aphrodite.attention.backends.flash_attn  # noqa
+    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
 except ImportError:
     BatchDecodeWithPagedKVCacheWrapper = None
+    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
     BatchPrefillWithPagedKVCacheWrapper = None
+    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
 
 import torch
 
@@ -17,6 +22,7 @@
                                                     AttentionImpl,
                                                     AttentionMetadata,
                                                     AttentionMetadataBuilder,
+                                                    AttentionState,
                                                     AttentionType)
 from aphrodite.attention.backends.utils import (PAD_SLOT_ID,
                                                 compute_slot_mapping,
@@ -48,6 +54,10 @@ def get_metadata_cls() -> Type["AttentionMetadata"]:
     def get_builder_cls() -> Type["FlashInferMetadataBuilder"]:
         return FlashInferMetadataBuilder
 
+    @staticmethod
+    def get_state_cls() -> Type["FlashInferState"]:
+        return FlashInferState
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
@@ -77,6 +87,156 @@ def get_supported_head_sizes() -> List[int]:
         return [64, 128, 256]
 
 
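+# NOTE: FlashInferState owns the FlashInfer objects that are reused across
+# forward passes and CUDA graph capture: the shared workspace buffer, the
+# prefill/decode wrappers, and the capture-time indptr/indices/last-page-len
+# buffers that the model runners previously managed by hand.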
+class FlashInferState(AttentionState):
+    def __init__(self, runner):
+        self.runner = runner
+        self._is_graph_capturing = False
+        self._workspace_buffer = None
+        self._decode_wrapper = None
+        self._prefill_wrapper = None
+
+    def _get_workspace_buffer(self):
+        if self._workspace_buffer is None:
+            self._workspace_buffer = torch.empty(
+                FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                dtype=torch.uint8,
+                device=self.runner.device)
+        return self._workspace_buffer
+
+    def _get_prefill_wrapper(self):
+        if self._prefill_wrapper is None:
+            self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
+                self._get_workspace_buffer(), "NHD")
+        return self._prefill_wrapper
+
+    def _get_decode_wrapper(self):
+        if self._decode_wrapper is None:
+            num_qo_heads = (self.runner.model_config.get_num_attention_heads(
+                self.runner.parallel_config))
+            num_kv_heads = self.runner.model_config.get_num_kv_heads(
+                self.runner.parallel_config)
+            use_tensor_cores = num_qo_heads // num_kv_heads >= 4
+            self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                self._get_workspace_buffer(),
+                "NHD",
+                use_tensor_cores=use_tensor_cores)
+        return self._decode_wrapper
+
+    @contextmanager
+    def graph_capture(self, max_batch_size: int):
+        self._is_graph_capturing = True
+        self._graph_decode_wrapper = None
+        self._graph_slot_mapping = torch.full((max_batch_size, ),
+                                              PAD_SLOT_ID,
+                                              dtype=torch.long,
+                                              device=self.runner.device)
+        self._graph_seq_lens = torch.ones(max_batch_size,
+                                          dtype=torch.int32,
+                                          device=self.runner.device)
+        self._graph_block_tables = torch.from_numpy(
+            self.runner.graph_block_tables).to(device=self.runner.device)
+        self._graph_decode_workspace_buffer = self._get_workspace_buffer()
+        self._graph_indices_buffer = torch.empty(
+            max_batch_size * self.runner.cache_config.num_gpu_blocks,
+            dtype=torch.int32,
+            device=self.runner.device)
+        self._graph_indptr_buffer = torch.empty(max_batch_size + 1,
+                                                dtype=torch.int32,
+                                                device=self.runner.device)
+        self._graph_last_page_len_buffer = torch.empty(
+            max_batch_size, dtype=torch.int32, device=self.runner.device)
+        yield
+        self._is_graph_capturing = False
+        del self._graph_slot_mapping
+        del self._graph_seq_lens
+        del self._graph_block_tables
+        del self._graph_decode_workspace_buffer
+        del self._graph_indices_buffer
+        del self._graph_indptr_buffer
+        del self._graph_last_page_len_buffer
+        del self._graph_decode_wrapper
+
+    def graph_clone(self, batch_size: int):
+        assert self._is_graph_capturing
+        state = self.__class__(self.runner)
+        state._workspace_buffer = self._graph_decode_workspace_buffer
+        state._decode_wrapper = self._graph_decode_wrapper
+        state._prefill_wrapper = self._get_prefill_wrapper()
+        return state
+
+    def graph_capture_get_metadata_for_batch(self, batch_size: int):
+        assert self._is_graph_capturing
+        _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1]
+        _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size]
+        num_qo_heads = (self.runner.model_config.get_num_attention_heads(
+            self.runner.parallel_config))
+        num_kv_heads = self.runner.model_config.get_num_kv_heads(
+            self.runner.parallel_config)
+        use_tensor_cores = num_qo_heads // num_kv_heads >= 4
+        self._graph_decode_wrapper = \
+            CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
+            self._graph_decode_workspace_buffer, _indptr_buffer,
+            self._graph_indices_buffer, _last_page_len_buffer, "NHD",
+            use_tensor_cores)
+        kv_cache_dtype = get_kv_cache_torch_dtype(
+            self.runner.kv_cache_dtype, self.runner.model_config.dtype)
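+        # Dummy paged-KV layout used only during capture: indptr/indices are
+        # plain aranges, so every sequence is given exactly one page, and the
+        # last page is reported as completely full (block_size).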
+        paged_kv_indptr_tensor_host = torch.arange(0,
+                                                   batch_size + 1,
+                                                   dtype=torch.int32)
+        paged_kv_indices_tensor_host = torch.arange(0,
+                                                    batch_size,
+                                                    dtype=torch.int32)
+        paged_kv_last_page_len_tensor_host = torch.full((batch_size, ),
+                                                        self.runner.block_size,
+                                                        dtype=torch.int32)
+        query_start_loc_host = torch.arange(0,
+                                            batch_size + 1,
+                                            dtype=torch.int32)
+        attn_metadata = self.runner.attn_backend.make_metadata(
+            num_prefills=0,
+            slot_mapping=self._graph_slot_mapping[:batch_size],
+            num_prefill_tokens=0,
+            num_decode_tokens=batch_size,
+            max_prefill_seq_len=0,
+            block_tables=self._graph_block_tables,
+            paged_kv_indptr=paged_kv_indptr_tensor_host,
+            paged_kv_indices=paged_kv_indices_tensor_host,
+            paged_kv_last_page_len=paged_kv_last_page_len_tensor_host,
+            num_qo_heads=num_qo_heads,
+            num_kv_heads=num_kv_heads,
+            head_dim=self.runner.model_config.get_head_size(),
+            page_size=self.runner.block_size,
+            seq_start_loc=None,
+            query_start_loc=query_start_loc_host,
+            device=self.runner.device,
+            data_type=kv_cache_dtype,
+            use_cuda_graph=True,
+            decode_wrapper=self._graph_decode_wrapper,
+            prefill_wrapper=None)
+        attn_metadata.begin_forward()
+        return attn_metadata
+
+    def get_graph_input_buffers(self, attn_metadata):
+        return {
+            "slot_mapping": attn_metadata.slot_mapping,
+        }
+
+    def prepare_graph_input_buffers(self, input_buffers, attn_metadata):
+        return
+
+    def begin_forward(self, model_input):
+        assert not self._is_graph_capturing
+        state = self
+        if model_input.attn_metadata.use_cuda_graph:
+            batch_size = model_input.input_tokens.shape[0]
+            state = (self.runner.graph_runners[model_input.virtual_engine]
+                     [batch_size].attn_state)
+        model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper(
+        )
+        model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper()
+        model_input.attn_metadata.begin_forward()
+
+
 @dataclass
 class FlashInferMetadata(AttentionMetadata):
     # Maximum sequence length among prefill batch. 0 if there are decoding
diff --git a/aphrodite/attention/backends/ipex_attn.py b/aphrodite/attention/backends/ipex_attn.py
index c2ee78b80..14d17f1a4 100644
--- a/aphrodite/attention/backends/ipex_attn.py
+++ b/aphrodite/attention/backends/ipex_attn.py
@@ -10,7 +10,8 @@
                                                    AttentionImpl,
                                                    AttentionMetadata,
                                                    AttentionType)
-from aphrodite.attention.backends.utils import CommonMetadataBuilder
+from aphrodite.attention.backends.utils import (CommonAttentionState,
+                                                CommonMetadataBuilder)
 from aphrodite.attention.ops.paged_attn import (PagedAttention,
                                                 PagedAttentionMetadata)
 
@@ -35,6 +36,10 @@ def get_metadata_cls() -> Type["IpexAttnMetadata"]:
     def get_builder_cls() -> Type["IpexAttnMetadataBuilder"]:
         return IpexAttnMetadataBuilder
 
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
diff --git a/aphrodite/attention/backends/openvino.py b/aphrodite/attention/backends/openvino.py
index 10f8daa70..4acf6b00e 100644
--- a/aphrodite/attention/backends/openvino.py
+++ b/aphrodite/attention/backends/openvino.py
@@ -1,11 +1,12 @@
 from dataclasses import dataclass
-from typing import List, Tuple
+from typing import List, Tuple, Type
 
 import openvino as ov
 import torch
 
 from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionMetadata)
+from aphrodite.attention.backends.utils import CommonAttentionState
 
 
 class OpenVINOAttentionBackend(AttentionBackend):
@@ -24,6 +25,10 @@ def get_impl_cls():
     def make_metadata(*args, **kwargs) -> "AttentionMetadata":
         raise NotImplementedError
 
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
     @staticmethod
     def make_openvino_metadata(*args, **kwargs) -> "OpenVINOAttentionMetadata":
         return OpenVINOAttentionMetadata(*args, **kwargs)
diff --git a/aphrodite/attention/backends/pallas.py b/aphrodite/attention/backends/pallas.py
index 19a27e524..3463f8584 100644
--- a/aphrodite/attention/backends/pallas.py
+++ b/aphrodite/attention/backends/pallas.py
@@ -8,6 +8,7 @@
                                                    AttentionImpl,
                                                    AttentionMetadata,
                                                    AttentionType)
+from aphrodite.attention.backends.utils import CommonAttentionState
 
 
 class PallasAttentionBackend(AttentionBackend):
@@ -20,6 +21,10 @@ def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
     def get_metadata_cls() -> Type["PallasMetadata"]:
         return PallasMetadata
 
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
diff --git a/aphrodite/attention/backends/placeholder_attn.py b/aphrodite/attention/backends/placeholder_attn.py
index 09e046a4e..673cf17c1 100644
--- a/aphrodite/attention/backends/placeholder_attn.py
+++ b/aphrodite/attention/backends/placeholder_attn.py
@@ -7,6 +7,7 @@
                                                    AttentionImpl,
                                                    AttentionMetadata,
                                                    AttentionMetadataBuilder)
+from aphrodite.attention.backends.utils import CommonAttentionState
 
 if TYPE_CHECKING:
     from aphrodite.task_handler.model_runner import ModelInputForGPUBuilder
@@ -34,6 +35,10 @@ def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]:
     def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]:
         return PlaceholderAttentionMetadata
 
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
diff --git a/aphrodite/attention/backends/rocm_flash_attn.py b/aphrodite/attention/backends/rocm_flash_attn.py
index 7f5f2139a..09bebbf7d 100644
--- a/aphrodite/attention/backends/rocm_flash_attn.py
+++ b/aphrodite/attention/backends/rocm_flash_attn.py
@@ -10,7 +10,8 @@
                                                    AttentionImpl,
                                                    AttentionMetadata,
                                                    AttentionType)
-from aphrodite.attention.backends.utils import CommonMetadataBuilder
+from aphrodite.attention.backends.utils import (CommonAttentionState,
+                                                CommonMetadataBuilder)
 from aphrodite.attention.ops.paged_attn import (PagedAttention,
                                                 PagedAttentionMetadata)
 
@@ -29,6 +30,10 @@ def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
     def get_metadata_cls() -> Type["AttentionMetadata"]:
         return ROCmFlashAttentionMetadata
 
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
     @staticmethod
     def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]:
         return ROCmFlashAttentionMetadataBuilder
diff --git a/aphrodite/attention/backends/torch_sdpa.py b/aphrodite/attention/backends/torch_sdpa.py
index 434ce2a82..891ec964b 100644
--- a/aphrodite/attention/backends/torch_sdpa.py
+++ b/aphrodite/attention/backends/torch_sdpa.py
@@ -10,7 +10,8 @@
                                                    AttentionImpl,
                                                    AttentionMetadata,
                                                    AttentionType)
-from aphrodite.attention.backends.utils import CommonMetadataBuilder
+from aphrodite.attention.backends.utils import (CommonAttentionState,
+                                                CommonMetadataBuilder)
 from aphrodite.attention.ops.paged_attn import PagedAttentionMetadata
 from aphrodite.common.utils import is_cpu
 
@@ -36,7 +37,11 @@ def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
     @staticmethod
     def get_metadata_cls() -> Type["AttentionMetadata"]:
         return TorchSDPAMetadata
-
+
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
     @staticmethod
     def get_builder_cls() -> Type["TorchSDPAMetadataBuilder"]:
         return TorchSDPAMetadataBuilder
diff --git a/aphrodite/attention/backends/utils.py b/aphrodite/attention/backends/utils.py
index 271a52ed0..fa4ba8f97 100644
--- a/aphrodite/attention/backends/utils.py
+++ b/aphrodite/attention/backends/utils.py
@@ -1,11 +1,16 @@
-from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union
 
 import numpy as np
 import torch
 
-from aphrodite.attention import AttentionMetadata, AttentionMetadataBuilder
+from aphrodite.attention import (AttentionMetadata, AttentionMetadataBuilder,
+                                 AttentionState)
 from aphrodite.common.utils import async_tensor_h2d, make_tensor_with_pad
 
+if TYPE_CHECKING:
+    from aphrodite.task_handler.model_runner_base import ModelRunnerBase
+
 # Error string(s) for encoder/decoder
 # unsupported attention scenarios
 STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
@@ -266,4 +271,68 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             context_lens_tensor=context_lens_tensor,
             block_tables=block_tables,
             use_cuda_graph=use_captured_graph,
         )
+
+
+class CommonAttentionState(AttentionState):
+    def __init__(self, runner: "ModelRunnerBase"):
+        self.runner = runner
+        self._is_graph_capturing = False
+
+    @contextmanager
+    def graph_capture(self, max_batch_size: int):
+        self._is_graph_capturing = True
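+        # Placeholder capture-only inputs: a padded slot mapping, unit
+        # sequence lengths, and the runner's dummy block tables. They are
+        # sliced per batch size in graph_capture_get_metadata_for_batch().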
+        self._graph_slot_mapping = torch.full((max_batch_size, ),
+                                              PAD_SLOT_ID,
+                                              dtype=torch.long,
+                                              device=self.runner.device)
+        self._graph_seq_lens = torch.ones(max_batch_size,
+                                          dtype=torch.int32,
+                                          device=self.runner.device)
+        self._graph_block_tables = torch.from_numpy(
+            self.runner.graph_block_tables).to(device=self.runner.device)
+        yield
+        self._is_graph_capturing = False
+        del self._graph_slot_mapping
+        del self._graph_seq_lens
+        del self._graph_block_tables
+
+    def graph_clone(self, batch_size: int) -> "CommonAttentionState":
+        assert self._is_graph_capturing
+        return self.__class__(self.runner)
+
+    def graph_capture_get_metadata_for_batch(self, batch_size: int):
+        assert self._is_graph_capturing
+        attn_metadata = self.runner.attn_backend.make_metadata(
+            num_prefills=0,
+            num_prefill_tokens=0,
+            num_decode_tokens=batch_size,
+            slot_mapping=self._graph_slot_mapping[:batch_size],
+            seq_lens=None,
+            seq_lens_tensor=self._graph_seq_lens[:batch_size],
+            max_query_len=None,
+            max_prefill_seq_len=0,
+            max_decode_seq_len=self.runner.max_seq_len_to_capture,
+            query_start_loc=None,
+            seq_start_loc=None,
+            context_lens_tensor=None,
+            block_tables=self._graph_block_tables[:batch_size],
+            use_cuda_graph=True,
+        )
+        return attn_metadata
+
+    def get_graph_input_buffers(self, attn_metadata) -> Dict[str, Any]:
+        return {
+            "slot_mapping": attn_metadata.slot_mapping,
+            "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
+            "block_tables": attn_metadata.decode_metadata.block_tables,
+        }
+
+    def prepare_graph_input_buffers(self, input_buffers,
+                                    attn_metadata) -> None:
+        input_buffers["seq_lens_tensor"].copy_(
+            attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
+        input_buffers["block_tables"].copy_(
+            attn_metadata.decode_metadata.block_tables, non_blocking=True)
+
+    def begin_forward(self, model_input) -> None:
+        return
diff --git a/aphrodite/attention/backends/xformers.py b/aphrodite/attention/backends/xformers.py
index 4aa875765..83c62461b 100644
--- a/aphrodite/attention/backends/xformers.py
+++ b/aphrodite/attention/backends/xformers.py
@@ -13,7 +13,8 @@
                                                    AttentionImpl,
                                                    AttentionMetadata,
                                                    AttentionType)
-from aphrodite.attention.backends.utils import CommonMetadataBuilder
+from aphrodite.attention.backends.utils import (CommonAttentionState,
+                                                CommonMetadataBuilder)
 from aphrodite.attention.ops.paged_attn import (PagedAttention,
                                                 PagedAttentionMetadata)
 
@@ -36,6 +37,10 @@ def get_metadata_cls() -> Type["AttentionMetadata"]:
     def get_builder_cls() -> Type["XFormersMetadataBuilder"]:
         return XFormersMetadataBuilder
 
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
diff --git a/aphrodite/spec_decode/draft_model_runner.py b/aphrodite/spec_decode/draft_model_runner.py
index ad856a689..b5d6b8a93 100644
--- a/aphrodite/spec_decode/draft_model_runner.py
+++ b/aphrodite/spec_decode/draft_model_runner.py
@@ -12,17 +12,6 @@ from aphrodite.attention.backends.rocm_flash_attn import (
     ROCmFlashAttentionMetadata as FlashAttentionMetadata)
 
-try:
-    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
-    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
-    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
-except ImportError:
-    BatchDecodeWithPagedKVCacheWrapper = None
-    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
-    BatchPrefillWithPagedKVCacheWrapper = None
-    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
-
 from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      LoRAConfig, ModelConfig, ParallelConfig,
                                      PromptAdapterConfig, SchedulerConfig)
@@ -88,11 +77,6 @@ def __init__(
             **kwargs,
         )
 
-        self.flashinfer_decode_workspace_buffer = None
-        self.flashinfer_decode_wrapper = None
-        self.flashinfer_prefill_workspace_buffer = None
-        self.flashinfer_prefill_wrapper = None
-
     def _update_sampling_metadata(self, sampling_metadata, num_seqs,
                                   num_queries):
 
@@ -268,36 +252,7 @@ def execute_model(
                     model_input.prompt_adapter_requests,
                     model_input.prompt_adapter_mapping)
 
-            if self.attn_backend.get_name() == "flashinfer":
-                assert model_input.attn_metadata is not None
-                assert model_input.input_tokens is not None
-                if self.flashinfer_decode_workspace_buffer is None:
-                    self.flashinfer_decode_workspace_buffer = torch.empty(
-                        FLASHINFER_WORKSPACE_BUFFER_SIZE,
-                        dtype=torch.uint8,
-                        device=self.device)
-                    self.flashinfer_decode_wrapper = \
-                        BatchDecodeWithPagedKVCacheWrapper(
-                            self.flashinfer_decode_workspace_buffer, "NHD")
-                    self.flashinfer_prefill_workspace_buffer = torch.empty(
-                        FLASHINFER_WORKSPACE_BUFFER_SIZE,
-                        dtype=torch.uint8,
-                        device=self.device)
-                    self.flashinfer_prefill_wrapper = \
-                        BatchPrefillWithPagedKVCacheWrapper(
-                            self.flashinfer_prefill_workspace_buffer, "NHD")
-
-                model_input.attn_metadata.prefill_wrapper = \
-                    self.flashinfer_prefill_wrapper
-                if model_input.attn_metadata.use_cuda_graph:
-                    batch_size = model_input.input_tokens.shape[0]
-                    model_input.attn_metadata.decode_wrapper = \
-                        self.graph_runners[model_input.
-                        virtual_engine][batch_size].flashinfer_decode_wrapper
-                else:
-                    model_input.attn_metadata.decode_wrapper = \
-                        self.flashinfer_decode_wrapper
-                model_input.attn_metadata.begin_forward()
+            self.attn_state.begin_forward(model_input)
 
             # Detect exec mode
             assert model_input.attn_metadata is not None
diff --git a/aphrodite/task_handler/enc_dec_model_runner.py b/aphrodite/task_handler/enc_dec_model_runner.py
index 96f5fcd94..7421679af 100644
--- a/aphrodite/task_handler/enc_dec_model_runner.py
+++ b/aphrodite/task_handler/enc_dec_model_runner.py
@@ -7,6 +7,7 @@
 
 from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionMetadata)
+from aphrodite.attention.backends.utils import PAD_SLOT_ID
 from aphrodite.attention.selector import (_Backend,
                                           get_env_variable_attn_backend,
                                           get_global_forced_attn_backend,
@@ -23,7 +24,7 @@
 from aphrodite.modeling import SamplingMetadata
 from aphrodite.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from aphrodite.task_handler.model_runner import (
-    _PAD_SLOT_ID, GPUModelRunnerBase, ModelInputForGPUBuilder,
+    GPUModelRunnerBase, ModelInputForGPUBuilder,
     ModelInputForGPUWithSamplingMetadata)
 from aphrodite.task_handler.model_runner_base import (
     _add_attn_metadata_broadcastable_dict,
@@ -387,7 +388,7 @@ def _prepare_encoder_model_input_tensors(
                     # initialized yet. In this case, we just use a dummy
                     # slot mapping.
                     # In embeddings, the block tables are {seq_id: None}.
-                    cross_slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
+                    cross_slot_mapping.extend([PAD_SLOT_ID] * seq_len)
                 else:
                     for i in range(0, seq_len):
                         block_number = seq_group_metadata.cross_block_table[
diff --git a/aphrodite/task_handler/model_runner.py b/aphrodite/task_handler/model_runner.py
index c29bc39b8..1596f109f 100644
--- a/aphrodite/task_handler/model_runner.py
+++ b/aphrodite/task_handler/model_runner.py
@@ -15,18 +15,9 @@
 import torch.nn as nn
 from loguru import logger
 
-try:
-    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
-    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
-    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
-except ImportError:
-    BatchDecodeWithPagedKVCacheWrapper = None
-    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
-    BatchPrefillWithPagedKVCacheWrapper = None
-    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
-
 from aphrodite.attention import AttentionMetadata, get_attn_backend
+from aphrodite.attention.backends.abstract import AttentionState
+from aphrodite.attention.backends.utils import CommonAttentionState
 from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      LoRAConfig, ModelConfig, ParallelConfig,
                                      PromptAdapterConfig, SchedulerConfig)
@@ -34,8 +25,7 @@
 from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
                                        SequenceGroupMetadata)
 from aphrodite.common.utils import (CudaMemoryProfiler, PyObjectCache,
-                                    async_tensor_h2d, flatten_2d_lists,
-                                    get_kv_cache_torch_dtype, is_hip,
+                                    async_tensor_h2d, flatten_2d_lists, is_hip,
                                     is_pin_memory_available)
 from aphrodite.distributed import get_pp_group
 from aphrodite.distributed.parallel_state import (
@@ -67,7 +57,6 @@
 if TYPE_CHECKING:
     from aphrodite.attention.backends.abstract import AttentionBackend
 
-_PAD_SLOT_ID = -1
 LORA_WARMUP_RANK = 8
 _BATCH_SIZE_ALIGNMENT = 8
 # Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
@@ -852,6 +841,11 @@ def __init__(
             self.block_size,
             self.model_config.is_attention_free(),
         )
+        if self.attn_backend:
+            self.attn_state = self.attn_backend.get_state_cls()(
+                weakref.proxy(self))
+        else:
+            self.attn_state = CommonAttentionState(weakref.proxy(self))
 
         # Multi-modal data support
         self.input_registry = input_registry
@@ -866,11 +860,6 @@ def __init__(
         self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
         self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None
 
-        self.flashinfer_decode_workspace_buffer = None
-        self.flashinfer_decode_wrapper = None
-        self.flashinfer_prefill_workspace_buffer = None
-        self.flashinfer_prefill_wrapper = None
-
         set_cpu_offload_max_bytes(
             int(self.cache_config.cpu_offload_gb * 1024**3))
 
@@ -1227,10 +1216,6 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
         max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
         input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
         input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
-        slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
-        slot_mapping.fill_(_PAD_SLOT_ID)
-        seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
-        block_tables = torch.from_numpy(self.graph_block_tables).cuda()
         intermediate_inputs = None
         if not get_pp_group().is_first_rank:
             intermediate_inputs = self.model.make_empty_intermediate_tensors(
@@ -1250,102 +1235,18 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
             bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
         ]
 
-        if self.attn_backend.get_name() == "flashinfer":
-            # For flashinfer, different batch sizes will share the
-            # same workspace buffer.
-            decode_workspace_buffer = \
-                torch.empty(FLASHINFER_WORKSPACE_BUFFER_SIZE,
-                            dtype=torch.uint8,
-                            device=self.device)
-            indices_buffer = torch.empty(max_batch_size *
-                                         self.cache_config.num_gpu_blocks,
-                                         dtype=torch.int32,
-                                         device=self.device)
-            indptr_buffer = torch.empty(max_batch_size + 1,
-                                        dtype=torch.int32,
-                                        device=self.device)
-            last_page_len_buffer = torch.empty(max_batch_size,
-                                               dtype=torch.int32,
-                                               device=self.device)
-
-        with graph_capture() as graph_capture_context:
+        with self.attn_state.graph_capture(
+                max_batch_size), graph_capture() as graph_capture_context:
+
             # NOTE: Capturing the largest batch size first may help reduce the
             # memory usage of CUDA graph.
             for virtual_engine in range(
                     self.parallel_config.pipeline_parallel_size):
                 for batch_size in reversed(batch_size_capture_list):
-                    if self.attn_backend.get_name() == "flashinfer":
-                        _indptr_buffer = indptr_buffer[:batch_size + 1]
-                        _last_page_len_buffer = last_page_len_buffer[:
-                                                                     batch_size]
-
-                        num_qo_heads = (
-                            self.model_config.get_num_attention_heads(
-                                self.parallel_config, self.tp_rank))
-                        num_kv_heads = self.model_config.get_num_kv_heads(
-                            self.parallel_config, self.tp_rank)
-                        if num_qo_heads // num_kv_heads >= 4:
-                            use_tensor_cores = True
-                        else:
-                            use_tensor_cores = False
-                        decode_wrapper = \
-                            CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
-                            decode_workspace_buffer, _indptr_buffer,
-                            indices_buffer, _last_page_len_buffer, "NHD",
-                            use_tensor_cores)
-                        kv_cache_dtype = get_kv_cache_torch_dtype(
-                            self.kv_cache_dtype, self.model_config.dtype)
-
-                        paged_kv_indptr_tensor_host = torch.arange(
-                            0, batch_size + 1, dtype=torch.int32)
-                        paged_kv_indices_tensor_host = torch.arange(
-                            0, batch_size, dtype=torch.int32)
-                        paged_kv_last_page_len_tensor_host = torch.full(
-                            (batch_size, ), self.block_size, dtype=torch.int32)
-                        query_start_loc_host = torch.arange(0,
-                                                            batch_size + 1,
-                                                            dtype=torch.int32)
-
-                        attn_metadata = self.attn_backend.make_metadata(
-                            num_prefills=0,
-                            slot_mapping=slot_mapping[:batch_size],
-                            num_prefill_tokens=0,
-                            num_decode_tokens=batch_size,
-                            max_prefill_seq_len=0,
-                            block_tables=block_tables,
-                            paged_kv_indptr=paged_kv_indptr_tensor_host,
-                            paged_kv_indices=paged_kv_indices_tensor_host,
-                            paged_kv_last_page_len=
-                            paged_kv_last_page_len_tensor_host,
-                            num_qo_heads=num_qo_heads,
-                            num_kv_heads=num_kv_heads,
-                            head_dim=self.model_config.get_head_size(),
-                            page_size=self.block_size,
-                            seq_start_loc=None,
-                            query_start_loc=query_start_loc_host,
-                            device=self.device,
-                            data_type=kv_cache_dtype,
-                            use_cuda_graph=True,
-                            decode_wrapper=decode_wrapper,
-                            prefill_wrapper=None)
-                        attn_metadata.begin_forward()
-                    else:
-                        attn_metadata = self.attn_backend.make_metadata(
-                            num_prefills=0,
-                            num_prefill_tokens=0,
-                            num_decode_tokens=batch_size,
-                            slot_mapping=slot_mapping[:batch_size],
-                            seq_lens=None,
-                            seq_lens_tensor=seq_lens[:batch_size],
-                            max_query_len=None,
-                            max_prefill_seq_len=0,
-                            max_decode_seq_len=self.max_seq_len_to_capture,
-                            query_start_loc=None,
-                            seq_start_loc=None,
-                            context_lens_tensor=None,
-                            block_tables=block_tables[:batch_size],
-                            use_cuda_graph=True,
-                        )
+                    attn_metadata = (
+                        self.attn_state.graph_capture_get_metadata_for_batch(
+                            batch_size))
+
                     if self.lora_config:
                         lora_mapping = LoRAMapping(
@@ -1363,17 +1264,8 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                             set(),
                             prompt_adapter_mapping)
 
                     graph_runner = CUDAGraphRunner(
-                        self.model, self.attn_backend.get_name())
-
-                    if self.attn_backend.get_name() == "flashinfer":
-                        graph_runner.flashinfer_indptr_buffer = _indptr_buffer
-                        graph_runner.flashinfer_indices_buffer = indices_buffer
-                        graph_runner.flashinfer_last_page_len_buffer = \
-                            _last_page_len_buffer
-                        graph_runner.flashinfer_decode_workspace_buffer = \
-                            decode_workspace_buffer
-                        graph_runner.flashinfer_decode_wrapper = \
-                            decode_wrapper
+                        self.model, self.attn_backend.get_name(),
+                        self.attn_state.graph_clone(batch_size))
 
                     capture_inputs = {
                         "input_ids":
@@ -1501,36 +1393,7 @@ def execute_model(
                 model_input.prompt_adapter_requests,
                 model_input.prompt_adapter_mapping)
 
-        if self.attn_backend.get_name() == "flashinfer":
-            assert model_input.attn_metadata is not None
-            assert model_input.input_tokens is not None
-            if self.flashinfer_decode_workspace_buffer is None:
-                self.flashinfer_decode_workspace_buffer = torch.empty(
-                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
-                    dtype=torch.uint8,
-                    device=self.device)
-                self.flashinfer_decode_wrapper = \
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.flashinfer_decode_workspace_buffer, "NHD")
-                self.flashinfer_prefill_workspace_buffer = torch.empty(
-                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
-                    dtype=torch.uint8,
-                    device=self.device)
-                self.flashinfer_prefill_wrapper = \
-                    BatchPrefillWithPagedKVCacheWrapper(
-                        self.flashinfer_prefill_workspace_buffer, "NHD")
-
-            model_input.attn_metadata.prefill_wrapper = \
-                self.flashinfer_prefill_wrapper
-            if model_input.attn_metadata.use_cuda_graph:
-                batch_size = model_input.input_tokens.shape[0]
-                model_input.attn_metadata.decode_wrapper = self.graph_runners[
-                    model_input.
-                    virtual_engine][batch_size].flashinfer_decode_wrapper
-            else:
-                model_input.attn_metadata.decode_wrapper = \
-                    self.flashinfer_decode_wrapper
-            model_input.attn_metadata.begin_forward()
+        self.attn_state.begin_forward(model_input)
 
         # Currently cuda graph is only supported by the decode phase.
         assert model_input.attn_metadata is not None
@@ -1598,22 +1461,17 @@ def execute_model(
 
 class CUDAGraphRunner:
 
-    def __init__(self, model: nn.Module, backend_name: str):
+    def __init__(self, model: nn.Module, backend_name: str,
+                 attn_state: AttentionState):
         self.model = model
         self.backend_name = backend_name
+        self.attn_state = attn_state
 
         self.input_buffers: Dict[str, torch.Tensor] = {}
         self.output_buffers: Dict[str, torch.Tensor] = {}
 
         self._graph: Optional[torch.cuda.CUDAGraph] = None
 
-        self.flashinfer_decode_workspace_buffer: Optional[torch.Tensor] = None
-        self.flashinfer_indptr_buffer: Optional[torch.Tensor] = None
-        self.flashinfer_indices_buffer: Optional[torch.Tensor] = None
-        self.flashinfer_last_page_len_buffer: Optional[torch.Tensor] = None
-        self.flashinfer_decode_wrapper: Optional[
-            CUDAGraphBatchDecodeWithPagedKVCacheWrapper] = None
-
     @property
     def graph(self):
         assert self._graph is not None
@@ -1678,25 +1536,13 @@ def capture(
         torch.cuda.synchronize()
 
         # Save the input and output buffers.
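+        # The attention-specific entries (slot_mapping and, for most
+        # backends, seq_lens_tensor / block_tables) now come from
+        # attn_state.get_graph_input_buffers() rather than being
+        # special-cased per backend here.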
-        if self.backend_name == "flashinfer":
-            self.input_buffers = {
-                "input_ids": input_ids,
-                "positions": positions,
-                "kv_caches": kv_caches,
-                "slot_mapping": attn_metadata.slot_mapping,
-                **kwargs,
-            }
-        else:
-            self.input_buffers = {
-                "input_ids": input_ids,
-                "positions": positions,
-                "kv_caches": kv_caches,
-                "slot_mapping": attn_metadata.slot_mapping,
-                "seq_lens_tensor":
-                attn_metadata.decode_metadata.seq_lens_tensor,
-                "block_tables": attn_metadata.decode_metadata.block_tables,
-                **kwargs,
-            }
+        self.input_buffers = {
+            "input_ids": input_ids,
+            "positions": positions,
+            "kv_caches": kv_caches,
+            **self.attn_state.get_graph_input_buffers(attn_metadata),
+            **kwargs,
+        }
         if intermediate_inputs is not None:
             self.input_buffers.update(intermediate_inputs.tensors)
         if get_pp_group().is_last_rank:
@@ -1725,12 +1571,8 @@ def forward(
         if self.backend_name != "No attention":
             self.input_buffers["slot_mapping"].copy_(
                 attn_metadata.slot_mapping, non_blocking=True)
-        if self.backend_name != "flashinfer":
-            self.input_buffers["seq_lens_tensor"].copy_(
-                attn_metadata.decode_metadata.seq_lens_tensor,
-                non_blocking=True)
-            self.input_buffers["block_tables"].copy_(
-                attn_metadata.decode_metadata.block_tables, non_blocking=True)
+        self.attn_state.prepare_graph_input_buffers(self.input_buffers,
+                                                    attn_metadata)
         if "seqlen_agnostic_capture_inputs" in self.input_buffers:
             self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
                                                       **kwargs)