Fix model_output_idx on HPU (#27)
madamczykhabana authored May 9, 2024
1 parent b5d4037 commit 90dfa92
Showing 1 changed file with 9 additions and 0 deletions.
vllm/model_executor/sampling_metadata.py: 9 additions & 0 deletions
@@ -192,6 +192,12 @@ def _prepare_seq_groups(
     # Total number of prompts from given sequence groups.
     num_prompts = 0
 
+    # FIXME: On HPU prompts are right-padded. We need to take that into account
+    # when updating model_output_idx
+    if is_hpu() and len(seq_lens) > 0:
+        assert seq_lens == query_lens, 'Prompt chunking is not yet supported on HPU!'
+        max_seq_len = max(seq_lens)
+
     for i, seq_group_metadata in enumerate(seq_group_metadata_list):
         seq_ids = list(seq_group_metadata.seq_data.keys())
         sampling_params = seq_group_metadata.sampling_params
@@ -219,10 +225,12 @@ def _prepare_seq_groups(
             prompt_logprob_len = (query_len - num_prefill_sample
                                   if do_sample else query_len)
             sample_len = num_prefill_sample if do_sample else 0
+            padding_len = 0 if not is_hpu() else max_seq_len - seq_len
         else:
             # Decode
             prompt_logprob_len = 0
             sample_len = len(seq_ids) if do_sample else 0
+            padding_len = 0
 
         # Update indices to select from the model output.
         """
@@ -241,6 +249,7 @@
             selected_token_indices.extend(
                 range(model_output_idx, model_output_idx + sample_len))
         model_output_idx += sample_len
+        model_output_idx += padding_len
 
         # We now find indices for logprob computation and sampling.
         """