diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py
index 469d18a4dd7af..286aa004798f1 100644
--- a/tests/basic_correctness/test_chunked_prefill.py
+++ b/tests/basic_correctness/test_chunked_prefill.py
@@ -19,19 +19,19 @@
 MODELS = [
     "facebook/opt-125m",
-    "meta-llama/Llama-3.2-1B",
+    #"meta-llama/Llama-3.2-1B",
 ]
 
 
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [32])
-@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
-@pytest.mark.parametrize("enforce_eager", [False, True])
+@pytest.mark.parametrize("chunked_prefill_token_size", [4,])
+@pytest.mark.parametrize("enforce_eager", [True])
 # NOTE: Increasing this in this suite will fail CI because we currently cannot
 # reset distributed env properly. Use a value > 1 just when you test.
 @pytest.mark.parametrize("tensor_parallel_size", [1])
-@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
+#@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
 def test_models(
     hf_runner,
     vllm_runner,
@@ -42,14 +42,12 @@ def test_models(
     chunked_prefill_token_size: int,
     enforce_eager: bool,
     tensor_parallel_size: int,
-    attention_backend: str,
-    monkeypatch,
 ) -> None:
     """
     Checks exact match decode between huggingface model and vllm runner with
     chunked prefill.
     """
-    override_backend_env_variable(monkeypatch, attention_backend)
+    #override_backend_env_variable(monkeypatch, attention_backend)
 
     max_num_seqs = chunked_prefill_token_size
     max_num_batched_tokens = chunked_prefill_token_size
@@ -76,7 +74,7 @@ def test_models(
     )
 
 
-@multi_gpu_test(num_gpus=2)
+'''@multi_gpu_test(num_gpus=2)
 @pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"])
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
@@ -320,4 +318,4 @@ def test_with_prefix_caching_cpu(
         chunk_size,
         1,
         dtype,
-    )
+    )'''
diff --git a/tests/conftest.py b/tests/conftest.py
index 9365b52dc74e1..7b4cb4c066b5d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -246,7 +246,7 @@ def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
             return x
 
         if device is None:
-            device = "cpu" if current_platform.is_cpu() else "cuda"
+            device = "cpu" if current_platform.is_cpu() or current_platform.is_hpu() else "cuda"
 
         if isinstance(x, dict):
             return {k: self.wrap_device(v, device) for k, v in x.items()}
diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py
index 96dafe8c2fcb1..5f927a4edb5d6 100644
--- a/vllm/attention/backends/hpu_attn.py
+++ b/vllm/attention/backends/hpu_attn.py
@@ -19,6 +19,8 @@
 from vllm.logger import init_logger
 from vllm.utils import is_fake_hpu
 
+import habana_frameworks.torch as htorch
+
 logger = init_logger(__name__)
 
 HPUFusedSDPA = None
@@ -95,6 +97,8 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
     cross_block_scales: Optional[torch.Tensor] = None
     cross_block_usage: Optional[torch.Tensor] = None
     cross_attn_bias: Optional[torch.Tensor] = None
+    decode_slot_mapping: Optional[torch.Tensor] = None
+    decode_block_list: Optional[torch.Tensor] = None
 
 
 class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
@@ -202,31 +206,99 @@ def forward(
                 v_scale=v_scale,
             )
 
-        batch_size, seq_len, hidden_size = query.shape
-        _, seq_len_kv, _ = key.shape
-
-        query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_kv_heads, self.head_size)
-        value = value.view(-1, self.num_kv_heads, self.head_size)
-        block_indices = attn_metadata.block_indices
-        block_offsets = attn_metadata.block_offsets
-        if attn_metadata.is_prompt:
-            key = key.unflatten(0, (block_indices.size(0), -1))
-            value = value.unflatten(0, (block_indices.size(0), -1))
-        if kv_cache is not None:
-            key_cache, value_cache = HPUPagedAttention.split_kv_cache(
-                kv_cache, self.num_kv_heads, self.head_size)
-
-            # Reshape the input keys and values and store them in the cache.
-            # If kv_cache is not provided, the new key and value tensors are
-            # not cached. This happens during the initial memory profiling run.
-            key_cache = self.k_cache(key, key_cache, block_indices,
-                                     block_offsets)
-            value_cache = self.v_cache(value, value_cache, block_indices,
-                                       block_offsets)
-
-        if attn_metadata.is_prompt:
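+        # The flattened query/key/value tensors are assumed to hold all prefill
+        # tokens first, followed by the decode tokens, so slicing at
+        # num_prefill_tokens separates the two phases of the mixed batch.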
+        hidden_size: int = 0
+        prefill_query = query[:attn_metadata.num_prefill_tokens].clone()
+        prefill_key = key[:attn_metadata.num_prefill_tokens].clone()
+        prefill_value = value[:attn_metadata.num_prefill_tokens].clone()
+        decode_query = query[attn_metadata.num_prefill_tokens:].clone()
+        decode_key = key[attn_metadata.num_prefill_tokens:].clone()
+        decode_value = value[attn_metadata.num_prefill_tokens:].clone()
+        htorch.core.mark_step()
+        if attn_metadata.num_decode_tokens > 0:
+            import pdb; pdb.set_trace()
+        if attn_metadata.num_prefill_tokens > 0:
+            # prefill preprocessing
+            hidden_size = prefill_query.shape[-1]
+            # print(prefill_query.shape, hidden_size)
+            prefill_query = prefill_query.reshape(attn_metadata.num_prefills,
+                attn_metadata.num_prefill_tokens // attn_metadata.num_prefills,
+                hidden_size)
+            hidden_size = prefill_key.shape[-1]
+            # print(prefill_key.shape, hidden_size)
+            prefill_key = prefill_key.reshape(attn_metadata.num_prefills,
+                attn_metadata.num_prefill_tokens // attn_metadata.num_prefills,
+                hidden_size)
+            hidden_size = prefill_value.shape[-1]
+            # print(prefill_value.shape, hidden_size)
+            prefill_value = prefill_value.reshape(attn_metadata.num_prefills,
+                attn_metadata.num_prefill_tokens // attn_metadata.num_prefills,
+                hidden_size)
+            prefill_batch_size, prefill_seq_len, prefill_hidden_size = prefill_query.shape
+            _, seq_len_kv, _ = prefill_key.shape
+            prefill_query = prefill_query.reshape(-1, self.num_heads, self.head_size)
+            prefill_key = prefill_key.reshape(-1, self.num_kv_heads, self.head_size)
+            prefill_value = prefill_value.reshape(-1, self.num_kv_heads, self.head_size)
+            block_indices = attn_metadata.block_indices
+            block_offsets = attn_metadata.block_offsets
+            prefill_key = prefill_key.unflatten(0, (block_indices.size(0), -1))
+            prefill_value = prefill_value.unflatten(0, (block_indices.size(0), -1))
+            if kv_cache is not None:
+                key_cache, value_cache = HPUPagedAttention.split_kv_cache(
+                    kv_cache, self.num_kv_heads, self.head_size)
+
+                # Reshape the input keys and values and store them in the cache.
+                # If kv_cache is not provided, the new key and value tensors are
+                # not cached. This happens during the initial memory profiling run.
+                prefill_key_cache = self.k_cache(prefill_key, key_cache, block_indices,
+                                                 block_offsets)
+                prefill_value_cache = self.v_cache(prefill_value, value_cache, block_indices,
+                                                   block_offsets)
+            htorch.core.mark_step()
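+        # Decode tokens are processed as a batch of single-token sequences:
+        # each decode query/key/value is reshaped to [num_decode_tokens, 1,
+        # hidden_size] before being written into the paged KV cache.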
+        if attn_metadata.num_decode_tokens > 0:
+            # decode preprocessing
+            import pdb; pdb.set_trace()
+            hidden_size = decode_query.shape[-1]
+            print(decode_query.shape, hidden_size)
+            decode_query = decode_query.reshape(attn_metadata.num_decode_tokens,
+                                                1, hidden_size)
+            hidden_size = decode_key.shape[-1]
+            print(decode_key.shape, hidden_size)
+            decode_key = decode_key.reshape(attn_metadata.num_decode_tokens,
+                                            1, hidden_size)
+            hidden_size = decode_value.shape[-1]
+            print(decode_value.shape, hidden_size)
+            decode_value = decode_value.reshape(attn_metadata.num_decode_tokens,
+                                                1, hidden_size)
+            decode_batch_size, decode_seq_len, decode_hidden_size = decode_query.shape
+            decode_query = decode_query.view(-1, self.num_heads, self.head_size)
+            decode_key = decode_key.view(-1, self.num_kv_heads, self.head_size)
+            decode_value = decode_value.view(-1, self.num_kv_heads, self.head_size)
+            block_indices = attn_metadata.block_indices
+            block_offsets = attn_metadata.block_offsets
+            if kv_cache is not None:
+                key_cache, value_cache = HPUPagedAttention.split_kv_cache(
+                    kv_cache, self.num_kv_heads, self.head_size)
+
+                # Reshape the input keys and values and store them in the cache.
+                # If kv_cache is not provided, the new key and value tensors are
+                # not cached. This happens during the initial memory profiling run.
+                decode_key_cache = self.k_cache(decode_key, key_cache, block_indices,
+                                                block_offsets)
+                decode_value_cache = self.v_cache(decode_value, value_cache, block_indices,
+                                                  block_offsets)
+            htorch.core.mark_step()
+        import pdb; pdb.set_trace()
+
+
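+        # Prompt attention and paged decode attention run separately below;
+        # their flattened outputs are concatenated at the end in the same
+        # prefill-then-decode token order as the input.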
+        prompt_output: torch.Tensor = None
+        decode_output: torch.Tensor = None
+        batch_size: int = 0
+        seq_len: int = 0
+        if attn_metadata.num_prefills > 0:
             # Prompt run.
+            batch_size = prefill_batch_size
+            seq_len = prefill_seq_len
+            hidden_size = prefill_hidden_size
             query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
             kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
                         self.head_size)
@@ -247,9 +319,9 @@ def forward(
                     attn_bias = None
 
                 out = ops.prompt_attention(
-                    query.view(query_shape),
-                    key.view(kv_shape),
-                    value.view(kv_shape),
+                    prefill_query.view(query_shape),
+                    prefill_key.view(kv_shape),
+                    prefill_value.view(kv_shape),
                     attn_bias=attn_bias,
                     p=0.0,
                     scale=self.scale,
@@ -262,11 +334,11 @@ def forward(
             else:
                 # TODO: enable FusedSDPA
                 out = HPUPagedAttention.forward_prefix(
-                    query=query.view(query_shape),
-                    key=key.view(kv_shape),
-                    value=value.view(kv_shape),
-                    key_cache=key_cache,
-                    value_cache=value_cache,
+                    query=prefill_query.view(query_shape),
+                    key=prefill_key.view(kv_shape),
+                    value=prefill_value.view(kv_shape),
+                    key_cache=prefill_key_cache,
+                    value_cache=prefill_value_cache,
                     block_list=attn_metadata.block_list,
                     attn_bias=attn_metadata.attn_bias,
                     scale=self.scale,
@@ -275,14 +347,22 @@ def forward(
                     softmax_op=self.softmax,
                     keys_fetch_func=self.k_cache.fetch_from_cache,
                     values_fetch_func=self.v_cache.fetch_from_cache)
-            output = out.reshape(batch_size, seq_len, hidden_size)
-        else:
+            prompt_output = out.reshape(batch_size, seq_len, hidden_size)
+            htorch.core.mark_step()
+        if attn_metadata.num_decode_tokens > 0:
             # Decoding run.
-            output = HPUPagedAttention.forward_decode(
+            import pdb; pdb.set_trace()
+            query = decode_query
+            key = decode_key
+            value = decode_value
+            batch_size = decode_batch_size
+            seq_len = decode_seq_len
+            hidden_size = decode_hidden_size
+            decode_output = HPUPagedAttention.forward_decode(
                 query=query,
-                key_cache=key_cache,
-                value_cache=value_cache,
-                block_list=attn_metadata.block_list,
+                key_cache=decode_key_cache,
+                value_cache=decode_value_cache,
+                block_list=attn_metadata.decode_block_list,
                 block_mapping=attn_metadata.block_mapping,
                 block_bias=attn_metadata.attn_bias,
                 block_scales=attn_metadata.block_scales,
@@ -294,8 +374,16 @@ def forward(
                 block2batch_matmul_op=self.block2batch_matmul,
                 keys_fetch_func=self.k_cache.fetch_from_cache,
                 values_fetch_func=self.v_cache.fetch_from_cache)
+            htorch.core.mark_step()
         # Reshape the output tensor.
-        return output.view(batch_size, seq_len, hidden_size)
+        if decode_output is None:
+            return prompt_output.view(batch_size * seq_len, hidden_size)
+        elif prompt_output is None:
+            return decode_output.view(batch_size * seq_len, hidden_size)
+        else:
+            prompt_output = prompt_output.view(batch_size * seq_len, hidden_size)
+            decode_output = decode_output.view(batch_size * seq_len, hidden_size)
+            return torch.cat((prompt_output, decode_output))
 
     def forward_encoder_decoder(
         self,
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 9f932c6f26eaa..1d572833fcb34 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -119,7 +119,7 @@ class EngineArgs:
     enable_prefix_caching: Optional[bool] = None
     disable_sliding_window: bool = False
     use_v2_block_manager: bool = True
-    use_padding_aware_scheduling: bool = current_platform.is_hpu()
+    use_padding_aware_scheduling: bool = False
     swap_space: float = 4  # GiB
     cpu_offload_gb: float = 0  # GiB
     gpu_memory_utilization: float = 0.90
diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py
index 8128028bd2ab8..270002e48ca90 100644
--- a/vllm/model_executor/sampling_metadata.py
+++ b/vllm/model_executor/sampling_metadata.py
@@ -291,8 +291,7 @@ def _prepare_seq_groups(
         else:
             # Decode
             prompt_logprob_len = 0
-            query_len = query_lens[i] if query_lens is not None and len(
-                query_lens) > 0 else 1
+            query_len = 1
             sample_len = len(seq_ids) * query_len if do_sample else 0
 
             if sampling_params.seed is not None and generators is not None:
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 7c3679d40546d..d5420a29054b6 100755
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -326,30 +326,30 @@ def _set_block_scales(self, metadata, device):
         metadata = metadata._replace(block_scales=block_scales)
         return metadata
 
-    def _set_indices_and_offsets(self, metadata, block_size, is_prompt):
+    def _set_indices_and_offsets(self, metadata, block_size):
         slot_mapping = metadata.slot_mapping.flatten()
         indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
-        if is_prompt:
+        if metadata.num_prefill_tokens > 0:
             indices = indices.unflatten(0, (-1, block_size))[:, 0]
             offsets = None
-        else:
-            offsets = torch.fmod(slot_mapping, block_size)
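+        # Decode offsets are computed from the separate decode slot mapping so
+        # they only cover the decode tokens of a mixed batch.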
+        if metadata.num_decode_tokens > 0:
+            decode_slot_mapping = metadata.decode_slot_mapping.flatten()
+            offsets = torch.fmod(decode_slot_mapping, block_size)
         metadata = metadata._replace(block_offsets=offsets,
                                      block_indices=indices)
         return metadata
 
-    def _update_metadata(self, attn_metadata, batch_size, seq_len, device,
+    def _update_metadata(self, attn_metadata, device,
                          dtype):
-        if attn_metadata.is_prompt:
-            attn_metadata = self._set_attn_bias(attn_metadata, batch_size,
-                                                seq_len, device, dtype)
-        else:
-            attn_metadata = self._set_block_mapping(attn_metadata, batch_size,
+        if attn_metadata.num_prefills > 0:
+            attn_metadata = self._set_attn_bias(attn_metadata, attn_metadata.num_prefills,
+                attn_metadata.num_prefill_tokens // attn_metadata.num_prefills, device, dtype)
+        if attn_metadata.num_decode_tokens > 0:
+            attn_metadata = self._set_block_mapping(attn_metadata, attn_metadata.num_decode_tokens,
                                                     device, dtype)
         attn_metadata = self._set_block_scales(attn_metadata, device)
         attn_metadata = self._set_indices_and_offsets(attn_metadata,
-                                                      self.block_size,
-                                                      attn_metadata.is_prompt)
+                                                      self.block_size)
         return attn_metadata
 
     def _prepare_cos_sin(self, positions):
@@ -372,8 +372,7 @@ def forward(self, *args, **kwargs):
             kwargs.pop('warmup_mode')
             input_ids = kwargs['input_ids']
             kwargs['attn_metadata'] = self._update_metadata(
-                kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1),
-                input_ids.device, self.dtype)
+                kwargs['attn_metadata'], input_ids.device, self.dtype)
             LoraMask.setLoraMask(kwargs.pop('lora_mask'))
             if self.layer_names is not None:
                 self._prepare_cos_sin(kwargs['positions'])
@@ -910,7 +909,7 @@ def _prepare_prompt(
                     if seq_group_metadata.sampling_params.prompt_logprobs
                     else 1))
 
         if any(context_lens):
-            assert not self.scheduler_config.chunked_prefill_enabled
+            # assert not self.scheduler_config.chunked_prefill_enabled
             # prefix caching
 
             max_num_block = max(len(bt) for bt in prefix_block_tables)
@@ -960,7 +959,7 @@ def _prepare_prompt(
         # Note: num_prefill_tokens is calculated using the length of
         # input_tokens after padding.
         num_prefill_tokens = input_tokens_tensor.numel()
-        if prefix_block_list_tensor:
+        if prefix_block_list_tensor is not None:
             prefix_block_list_tensor = prefix_block_list_tensor.to(
                 self.device, non_blocking=True)
         input_tokens_tensor = input_tokens_tensor.to(  # type: ignore
@@ -1184,6 +1183,7 @@ def prepare_input_tensors(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[TModelInputForHPU, SamplingMetadata]:
+        print('prepare_input_tensors')
         if len(seq_group_metadata_list) == 0:
             return self._model_input_cls(), None
 
@@ -1204,7 +1204,8 @@ def prepare_input_tensors(
             self.profiler.start('internal', base_event_name)
 
             real_batch_size = len(seq_group_metadata_list)
-            batch_size_padded = self.bucketing_ctx.get_padded_batch_size(
+            batch_size_padded = real_batch_size
+            '''batch_size_padded = self.bucketing_ctx.get_padded_batch_size(
                 real_batch_size, is_prompt)
             batch_size_padding = batch_size_padded - real_batch_size
             seq_group_metadata_list = seq_group_metadata_list.copy()
@@ -1212,7 +1213,7 @@
                 dummy_seq_group_metadata = self.create_dummy_seq_group_metadata(
                     0, 0, is_prompt)
                 seq_group_metadata_list.extend(dummy_seq_group_metadata
-                                               for _ in range(batch_size_padding))
+                                               for _ in range(batch_size_padding))'''
 
         prefill_reqs = []
         decode_reqs = []
@@ -1261,21 +1262,34 @@ def prepare_input_tensors(
         # NOTE(kzawora): Here we diverge from GPU code - we don't
         # support mixed batches, so we either use decode or prefill
         # inputs, without coalescing.
-        assert (num_prefills == 0 and num_decode_tokens > 0) or (
+        '''if num_decode_tokens > 0:
+            import pdb; pdb.set_trace()'''
+        '''assert (num_prefills == 0 and num_decode_tokens > 0) or (
             num_prefills > 0
-            and num_decode_tokens == 0), "HPU does not support mixed batches!"
+            and num_decode_tokens == 0), "HPU does not support mixed batches!"'''
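+        # Mixed batches are handled in this path: prefill and decode inputs are
+        # flattened and concatenated below, keeping prefill tokens first.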
         if num_decode_tokens > 0:
-            input_tokens = decode_input_tokens
+            '''input_tokens = decode_input_tokens
             input_positions = decode_input_positions
-            slot_mapping = decode_slot_mapping
+            slot_mapping = decode_slot_mapping'''
             lora_index_mapping = decode_lora_index_mapping
             lora_prompt_mapping = decode_lora_prompt_mapping
             lora_requests = decode_lora_requests
             lora_ids = decode_lora_ids
-
+        if num_prefills > 0:
+            max_len = input_tokens.size(1)
+            input_tokens = input_tokens.flatten()
+            input_positions = input_positions.flatten()
+            if num_decode_tokens > 0:
+                decode_input_tokens = decode_input_tokens.flatten()
+                decode_input_positions = decode_input_positions.flatten()
+                input_tokens = torch.cat((input_tokens, decode_input_tokens), dim=0)
+                input_positions = torch.cat((input_positions, decode_input_positions), dim=0)
+        else:
+            max_len = decode_input_tokens.size(1)
+            input_tokens = decode_input_tokens.flatten()
+            input_positions = decode_input_positions.flatten()
         # FIXME: We need to adjust selected_token_indices to accommodate
         # for padding
-        max_len = input_tokens.size(1)
         paddings = [max_len - q for q in query_lens]
         paddings = [0] + paddings[:-1]
         paddings = list(itertools.accumulate(paddings))
@@ -1301,7 +1315,6 @@ def prepare_input_tensors(
         if (prefill_attn_metadata is not None
                 and decode_attn_metadata is not None):
             batch_type = BatchType.MIXED
-            raise NotImplementedError("Mixed batch is not supported on HPU")
         elif prefill_attn_metadata is not None:
             batch_type = BatchType.PREFILL
         else:
@@ -1328,6 +1341,20 @@ def prepare_input_tensors(
             assert decode_attn_metadata is not None
             metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
 
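+        # For mixed batches, the decode-specific fields are folded into the
+        # prefill metadata object so a single HPUAttentionMetadata describes
+        # both phases of the batch.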
+        #input_tokens = input_tokens.flatten()
+        if prefill_attn_metadata:
+            attn_metadata = prefill_attn_metadata
+            if decode_attn_metadata:
+                attn_metadata.num_decode_tokens = decode_attn_metadata.num_decode_tokens
+                attn_metadata.decode_slot_mapping = decode_attn_metadata.slot_mapping
+                attn_metadata.decode_block_list = decode_attn_metadata.block_list
+                attn_metadata.block_usage = decode_attn_metadata.block_usage
+                attn_metadata.block_groups = decode_attn_metadata.block_groups
+        else:
+            attn_metadata = decode_attn_metadata
+            attn_metadata.decode_slot_mapping = decode_attn_metadata.slot_mapping
+            attn_metadata.decode_block_list = decode_attn_metadata.block_list
+
         attn_metadata = prefill_attn_metadata if \
             prefill_attn_metadata is not None else decode_attn_metadata
 
@@ -1348,7 +1375,7 @@ def _seq_len(self, attn_metadata):
         if attn_metadata.num_prefills != 0:
             return attn_metadata.slot_mapping.size(1)
         else:
-            return attn_metadata.block_list.numel()
+            return attn_metadata.decode_block_list.numel()
 
     def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
         # NOTE(kzawora): To anyone working on this in the future:
@@ -1384,6 +1411,11 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
             'block_offsets',
             'block_scales',
             'block_groups',
+            'num_prefills',
+            'num_prefill_tokens',
+            'num_decode_tokens',
+            'decode_slot_mapping',
+            'decode_block_list',
         ])
         return attention_metadata
 
@@ -1435,7 +1467,8 @@ def warmup_scenario(self,
                         is_pt_profiler_run=False,
                         is_lora_profile_run=False,
                         temperature=0) -> None:
-        use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
+        return
+        '''use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
         scenario_name = ("warmup_"
                          f"{'prompt' if is_prompt else 'decode'}_"
                          f"bs{batch_size}_"
@@ -1523,7 +1556,7 @@ def warmup_scenario(self,
         if profiler:
             profiler.stop()
         self.profiler.end()
-        gc.collect()
+        gc.collect()'''
 
     def remove_all_loras(self):
         if not self.lora_manager: