fixed prefill error
kakao-kevin-us committed Oct 23, 2024
1 parent 708e3c4 commit bf18a77
Showing 3 changed files with 19 additions and 6 deletions.
vllm/attention/backends/flash_attn.py: 8 additions & 2 deletions

@@ -349,6 +349,11 @@ def _add_seq_group(
                 else:
                     block_table = block_tables[seq_id][
                         -curr_sliding_window_block:]
+
+            print(f"prefix cache hit: {prefix_cache_hit}")
+            print(f"chunked prefill enabled: {chunked_prefill_enabled}")
+            print(f"prompt: {is_prompt}")
+            print(f"block table: {block_table}")
             self.block_tables.append(block_table)

             # Compute slot mapping.
@@ -400,6 +405,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
                 for inter_data in self.input_builder.inter_data_list
             ])
         for inter_data in self.input_builder.inter_data_list:
+            print(f"inter_data: {inter_data}")
             self._add_seq_group(inter_data,
                                 self.input_builder.chunked_prefill_enabled,
                                 prefix_cache_hit)
@@ -426,8 +432,8 @@ def build(self, seq_lens: List[int], query_lens: List[int],
                 num_seqs, self.block_tables)
         else:
             print(f"block tables: {self.block_tables}")
-            if self.block_tables[0] is None:
-                self.block_tables = [list() for _ in range(num_seqs)]
+            # if self.block_tables[0] is None:
+            #     self.block_tables = [list() for _ in range(num_seqs)]
             block_tables = make_tensor_with_pad(
                 self.block_tables,
                 pad=0,
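The functional change in this file is the commented-out guard at the end of build(): self.block_tables is no longer wiped to empty per-sequence lists when its first entry is None, so a prefill batch keeps whatever tables were collected and hands them straight to make_tensor_with_pad. A minimal sketch of that padding step, assuming make_tensor_with_pad right-pads ragged per-sequence block tables into one rectangular tensor (the helper name and the None handling below are illustrative, not vLLM's implementation):

import torch

def pad_block_tables(block_tables, pad=0, dtype=torch.int32):
    # Illustrative stand-in for vllm.utils.make_tensor_with_pad:
    # right-pad ragged per-sequence block tables to a rectangle.
    rows = [list(bt) if bt else [] for bt in block_tables]  # treat None as empty
    width = max((len(r) for r in rows), default=0)
    return torch.tensor([r + [pad] * (width - len(r)) for r in rows],
                        dtype=dtype)

# Two sequences holding different numbers of KV-cache blocks:
print(pad_block_tables([[3, 7, 9], [4]]))
# tensor([[3, 7, 9],
#         [4, 0, 0]], dtype=torch.int32)

Under that reading, resetting every table to an empty list would have discarded valid block IDs for prefill requests, which is consistent with the commit title.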
vllm/model_executor/layers/pooler.py: 4 additions & 1 deletion

@@ -28,11 +28,12 @@ class Pooler(nn.Module):
         normalize: Whether to normalize the pooled data.
     """

-    def __init__(self, pooling_type: PoolingType, normalize: bool):
+    def __init__(self, pooling_type: PoolingType, normalize: bool, softmax: bool = False):
         super().__init__()

         self.pooling_type = pooling_type
         self.normalize = normalize
+        self.softmax = softmax

     def forward(
         self,
@@ -63,6 +64,8 @@ def forward(

         if self.normalize:
             pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
+        if self.softmax:
+            pooled_data = nn.functional.softmax(pooled_data, dim=-1)

         pooled_outputs = [
             EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data
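The new softmax flag gives Pooler a second optional post-processing step next to L2 normalization: embedding models keep normalize=True, while a classification head can ask for probabilities instead. Since the flag defaults to False, existing callers that construct Pooler with two arguments are unaffected. A minimal sketch of the post-processing contract this diff adds, written as a standalone function rather than the full Pooler module:

import torch
import torch.nn as nn

def postprocess(pooled_data: torch.Tensor,
                normalize: bool,
                softmax: bool = False) -> torch.Tensor:
    # Mirrors the two optional steps in Pooler.forward after pooling.
    if normalize:
        pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
    if softmax:
        pooled_data = nn.functional.softmax(pooled_data, dim=-1)
    return pooled_data

logits = torch.tensor([[2.0, 0.5, -1.0]])  # one sequence, three labels
probs = postprocess(logits, normalize=False, softmax=True)
print(probs, probs.sum(dim=-1))            # each row now sums to 1.0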
vllm/model_executor/models/qwen2_cls.py: 7 additions & 3 deletions

@@ -85,12 +85,13 @@ def __init__(
         self.lora_config = lora_config

         self.quant_config = quant_config
+        print(f"config: {config}\ncache_config: {cache_config}\nquant_config: {quant_config}")
         self.model = Qwen2Model(config, cache_config, quant_config)

-        self.score = ColumnParallelLinear(config.hidden_size,
+        self.score = RowParallelLinear(config.hidden_size,
                                        config.num_labels,
                                        quant_config=quant_config)
-        self._pooler = Pooler(pooling_type=PoolingType.ALL, normalize=False)
+        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False, softmax=True)

     def forward(
         self,
@@ -100,10 +101,11 @@
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
+        print(f"{input_ids}\n{positions}\n{kv_caches}\n{attn_metadata}\n{intermediate_tensors}")
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata, intermediate_tensors)
         hidden_states = hidden_states[0]
         logits, _ = self.score(hidden_states)
+        print(logits)
         return logits

     def pooler(
@@ -135,6 +137,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
+                    print(f"bias is ignored: {name}")
                     continue
                 if is_pp_missing_parameter(name, self):
                     continue
@@ -145,6 +148,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             else:
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
+                    print(f"bias is ignored: {name}")
                     continue
                 # Remapping the name of FP8 kv-scale.
                 name = maybe_remap_kv_scale_name(name, params_dict)
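Setting the debug prints aside, this file makes three changes that define the classification path: the score head becomes a RowParallelLinear (which reduces its output across tensor-parallel ranks, so every rank presumably ends up with the full num_labels-wide logits), and the pooler moves from PoolingType.ALL to PoolingType.LAST with softmax=True, so only the final token's logits survive and are converted to probabilities. A hedged sketch of that flow, with plain tensors standing in for Qwen2Model and RowParallelLinear:

import torch

torch.manual_seed(0)
hidden_size, num_labels, seq_len = 8, 3, 5

score_weight = torch.randn(num_labels, hidden_size)  # stand-in for the score head

hidden_states = torch.randn(seq_len, hidden_size)    # hidden states for one prompt
token_logits = hidden_states @ score_weight.t()      # forward(): score every token
last_logits = token_logits[-1]                       # PoolingType.LAST keeps the final token
probs = torch.softmax(last_logits, dim=-1)           # softmax=True in the Pooler
print(probs)                                         # one probability per label, sums to 1

The load_weights hunks only add prints when a GPTQ bias is skipped; the skip logic itself is unchanged.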
