Fix: attn make_metadata with kv_store_meta
Signed-off-by: Dahai Tang <[email protected]>
Dahai Tang committed Dec 5, 2024
1 parent 068342c commit 8d2816b
Showing 5 changed files with 22 additions and 5 deletions.
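
The change is identical in every file: each site that constructs backend attention metadata now passes an explicit kv_store_meta argument, with KVStoreMeta.null() as the "no KV-store work for this batch" placeholder. The real class lives in vllm/store/kv_store.py and is not part of this diff, so the sketch below is only a guess at its shape: the null() classmethod is the one piece these call sites confirm, and the fields shown are illustrative assumptions.

    # Hypothetical sketch of KVStoreMeta. Only null() is visible in this
    # commit; the fields below are assumptions, not vllm's actual definition.
    from dataclasses import dataclass, field

    import torch


    @dataclass
    class KVStoreMeta:
        # Assumed: slot mappings for an external KV store to load from or
        # save to for the current batch (empty tensors mean nothing to do).
        load_slot_mapping: torch.Tensor = field(
            default_factory=lambda: torch.empty(0, dtype=torch.long))
        save_slot_mapping: torch.Tensor = field(
            default_factory=lambda: torch.empty(0, dtype=torch.long))

        @classmethod
        def null(cls) -> "KVStoreMeta":
            # The empty marker that every call site in this commit passes.
            return cls()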
9 changes: 7 additions & 2 deletions tests/kernels/utils.py
@@ -13,6 +13,7 @@
 
 from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
 from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.store.kv_store import KVStoreMeta
 from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
                         STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)

@@ -925,7 +926,9 @@ def make_test_metadata(
             cross_slot_mapping=(None if cross_kv_mmap is None else
                                 cross_kv_mmap.slot_mapping),
             cross_block_tables=(None if cross_kv_mmap is None else
-                                cross_kv_mmap.block_tables))
+                                cross_kv_mmap.block_tables),
+            kv_store_meta=KVStoreMeta.null(),
+        )
 
     else:  # not is_prompt
         # Decode-phase scenario

@@ -975,7 +978,9 @@ def make_test_metadata(
             cross_slot_mapping=(None if cross_kv_mmap is None else
                                 cross_kv_mmap.slot_mapping),
             cross_block_tables=(None if cross_kv_mmap is None else
-                                cross_kv_mmap.block_tables))
+                                cross_kv_mmap.block_tables),
+            kv_store_meta=KVStoreMeta.null(),
+        )
 
 
 def assert_actual_matches_ideal(test_params: PhaseTestParameters,
4 changes: 3 additions & 1 deletion vllm/attention/backends/flashinfer.py
@@ -237,7 +237,9 @@ def graph_capture_get_metadata_for_batch(
             q_data_type=self.runner.model_config.dtype,
             use_cuda_graph=True,
             decode_wrapper=self._graph_decode_wrapper,
-            prefill_wrapper=None)
+            prefill_wrapper=None,
+            kv_store_meta=KVStoreMeta.null(),
+        )
         attn_metadata.begin_forward()
         return attn_metadata
8 changes: 6 additions & 2 deletions vllm/worker/hpu_model_runner.py
@@ -40,6 +40,7 @@
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (IntermediateTensors, SequenceData,
                            SequenceGroupMetadata)
+from vllm.store.kv_store import KVStoreMeta
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
 from vllm.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase,

@@ -888,7 +889,8 @@ def _prepare_prompt(
             num_decode_tokens=0,
             slot_mapping=slot_mapping,
             multi_modal_placeholder_index_maps=
-            None  # FIXME(kzawora): mutli-modality will not work here
+            None,  # FIXME(kzawora): mutli-modality will not work here
+            kv_store_meta=KVStoreMeta.null(),
         )
         multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)

@@ -1042,7 +1044,9 @@ def _prepare_decode(
             num_prefill_tokens=0,
             num_decode_tokens=num_decode_tokens,
             slot_mapping=slot_mapping,
-            multi_modal_placeholder_index_maps=None)
+            multi_modal_placeholder_index_maps=None,
+            kv_store_meta=KVStoreMeta.null(),
+        )
         return PrepareDecodeMetadata(input_tokens=input_tokens,
                                      input_positions=input_positions,
                                      attn_metadata=attn_metadata,
3 changes: 3 additions & 0 deletions vllm/worker/tpu_model_runner.py
@@ -19,6 +19,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
                            Logprob, SequenceGroupMetadata, SequenceOutput)
+from vllm.store.kv_store import KVStoreMeta
 from vllm.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase,
     _add_attn_metadata_broadcastable_dict,

@@ -183,6 +184,7 @@ def _dummy_run(
                 num_prefill_tokens=batch_size * seq_len,
                 num_decode_tokens=0,
                 slot_mapping=slot_mapping,
+                kv_store_meta=KVStoreMeta.null(),
                 multi_modal_placeholder_index_maps=None,
                 block_tables=None,
                 context_lens=None,

@@ -205,6 +207,7 @@ def _dummy_run(
                 block_tables=block_tables,
                 context_lens=context_lens,
                 effective_query_lens=effective_query_lens,
+                kv_store_meta=KVStoreMeta.null(),
             )
         else:
             assert seq_len == 1
3 changes: 3 additions & 0 deletions vllm/worker/xpu_model_runner.py
@@ -22,6 +22,7 @@
                              MultiModalRegistry)
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
+from vllm.store.kv_store import KVStoreMeta
 from vllm.utils import DeviceMemoryProfiler, make_tensor_with_pad
 from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata
 from vllm.worker.model_runner_base import (

@@ -270,6 +271,7 @@ def _prepare_prompt(
             num_prefill_tokens=num_prompt_tokens,
             num_decode_tokens=0,
             block_tables=torch.tensor([], device=self.device, dtype=torch.int),
+            kv_store_meta=KVStoreMeta.null(),
         )
 
         multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)

@@ -354,6 +356,7 @@ def _prepare_decode(
             num_decode_tokens=len(input_tokens),
             num_prefills=0,
             block_tables=block_tables,
+            kv_store_meta=KVStoreMeta.null(),
         )
         return (
             input_tokens,
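
A closing design note: keeping kv_store_meta a required argument forces every backend and model runner to state its KV-store intent explicitly, at the cost of touching each call site, which is exactly what this commit does. The alternative would be a default on the metadata dataclass itself, sketched here against a simplified stand-in (not vllm's real AttentionMetadata; only the kv_store_meta field mirrors this commit):

    # Sketch of the alternative: a dataclass-level default.
    # AttentionMetadataSketch is a hypothetical stand-in.
    from dataclasses import dataclass, field

    from vllm.store.kv_store import KVStoreMeta


    @dataclass
    class AttentionMetadataSketch:
        num_prefills: int
        num_prefill_tokens: int
        num_decode_tokens: int
        # A default would let call sites omit the argument, but a
        # forgotten field then passes silently instead of failing loudly.
        kv_store_meta: KVStoreMeta = field(default_factory=KVStoreMeta.null)

With the required field, a missed call site raises a TypeError at construction time, which is presumably why this fix spans the test helpers and four model runners at once.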
