From bf18a77caf833184dee9a4a6b3f9083d07a8819f Mon Sep 17 00:00:00 2001
From: kevin-us
Date: Wed, 23 Oct 2024 16:20:05 +0900
Subject: [PATCH] fixed prefill error

---
 vllm/attention/backends/flash_attn.py   | 10 ++++++++--
 vllm/model_executor/layers/pooler.py    |  5 ++++-
 vllm/model_executor/models/qwen2_cls.py | 10 +++++++---
 3 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 4cb29e3462d8b..91e70ada2e59f 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -349,6 +349,11 @@ def _add_seq_group(
                 else:
                     block_table = block_tables[seq_id][
                         -curr_sliding_window_block:]
+
+            print(f"prefix cache hit: {prefix_cache_hit}")
+            print(f"chunked prefill enabled: {chunked_prefill_enabled}")
+            print(f"prompt: {is_prompt}")
+            print(f"block table: {block_table}")
             self.block_tables.append(block_table)
 
             # Compute slot mapping.
@@ -400,6 +405,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             for inter_data in self.input_builder.inter_data_list
         ])
         for inter_data in self.input_builder.inter_data_list:
+            print(f"inter_data: {inter_data}")
             self._add_seq_group(inter_data,
                                 self.input_builder.chunked_prefill_enabled,
                                 prefix_cache_hit)
@@ -426,8 +432,8 @@ def build(self, seq_lens: List[int], query_lens: List[int],
                 num_seqs, self.block_tables)
         else:
             print(f"block tables: {self.block_tables}")
-            if self.block_tables[0] is None:
-                self.block_tables = [list() for _ in range(num_seqs)]
+            # if self.block_tables[0] is None:
+            #     self.block_tables = [list() for _ in range(num_seqs)]
             block_tables = make_tensor_with_pad(
                 self.block_tables,
                 pad=0,
diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py
index 3455a4ccf282f..fcb18767aaf83 100644
--- a/vllm/model_executor/layers/pooler.py
+++ b/vllm/model_executor/layers/pooler.py
@@ -28,11 +28,12 @@ class Pooler(nn.Module):
         normalize: Whether to normalize the pooled data.
""" - def __init__(self, pooling_type: PoolingType, normalize: bool): + def __init__(self, pooling_type: PoolingType, normalize: bool, softmax: bool = False): super().__init__() self.pooling_type = pooling_type self.normalize = normalize + self.softmax = softmax def forward( self, @@ -63,6 +64,8 @@ def forward( if self.normalize: pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1) + if self.softmax: + pooled_data = nn.functional.softmax(pooled_data, dim=-1) pooled_outputs = [ EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data diff --git a/vllm/model_executor/models/qwen2_cls.py b/vllm/model_executor/models/qwen2_cls.py index cf4cb8467d5cc..5286b70b1cb3e 100644 --- a/vllm/model_executor/models/qwen2_cls.py +++ b/vllm/model_executor/models/qwen2_cls.py @@ -85,12 +85,13 @@ def __init__( self.lora_config = lora_config self.quant_config = quant_config + print(f"config: {config}\ncache_config: {cache_config}\nquant_config: {quant_config}") self.model = Qwen2Model(config, cache_config, quant_config) - self.score = ColumnParallelLinear(config.hidden_size, + self.score = RowParallelLinear(config.hidden_size, config.num_labels, quant_config=quant_config) - self._pooler = Pooler(pooling_type=PoolingType.ALL, normalize=False) + self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False, softmax=True) def forward( self, @@ -100,10 +101,11 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: + print(f"{input_ids}\n{positions}\n{kv_caches}\n{attn_metadata}\n{intermediate_tensors}") hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors) - hidden_states = hidden_states[0] logits, _ = self.score(hidden_states) + print(logits) return logits def pooler( @@ -135,6 +137,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: + print(f"bias is ignored: {name}") continue if is_pp_missing_parameter(name, self): continue @@ -145,6 +148,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: + print(f"bias is ignored: {name}") continue # Remapping the name of FP8 kv-scale. name = maybe_remap_kv_scale_name(name, params_dict)