Skip to content

Commit

Permalink
[BugFix] Fix recovery logic for sequence group (vllm-project#2186)
Browse files Browse the repository at this point in the history
  • Loading branch information
WoosukKwon authored Dec 21, 2023
1 parent 3a4fd5c commit a1b9cb2
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
6 changes: 3 additions & 3 deletions vllm/core/block_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def __init__(
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.
seq = seq_group.get_seqs()[0]
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
num_required_blocks = len(seq.logical_token_blocks)
if self.block_sliding_window is not None:
num_required_blocks = min(num_required_blocks,
Expand All @@ -122,7 +122,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
def allocate(self, seq_group: SequenceGroup) -> None:
# NOTE: Here we assume that all sequences in the group have the same
# prompt.
seq = seq_group.get_seqs()[0]
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]

# Allocate new physical token blocks that will store the prompt tokens.
block_table: BlockTable = []
Expand All @@ -137,7 +137,7 @@ def allocate(self, seq_group: SequenceGroup) -> None:
block_table.append(block)

# Assign the block table for each sequence.
for seq in seq_group.get_seqs():
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
self.block_tables[seq.seq_id] = block_table.copy()

def can_append_slot(self, seq_group: SequenceGroup) -> bool:
Expand Down
12 changes: 7 additions & 5 deletions vllm/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,17 @@ def _schedule(self) -> SchedulerOutputs:
while self.waiting:
seq_group = self.waiting[0]

assert seq_group.num_seqs() == 1, (
waiting_seqs = seq_group.get_seqs(
status=SequenceStatus.WAITING)
assert len(waiting_seqs) == 1, (
"Waiting sequence group should have only one prompt "
"sequence.")
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
num_prompt_tokens = waiting_seqs[0].get_len()
if num_prompt_tokens > self.prompt_limit:
logger.warning(
f"Input prompt ({num_prompt_tokens} tokens) is too long"
f" and exceeds limit of {self.prompt_limit}")
for seq in seq_group.get_seqs():
for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)
self.waiting.pop(0)
Expand All @@ -161,7 +163,7 @@ def _schedule(self) -> SchedulerOutputs:
logger.warning(
f"Input prompt ({num_prompt_tokens} tokens) is too long"
f" and exceeds the capacity of block_manager")
for seq in seq_group.get_seqs():
for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)
self.waiting.pop(0)
Expand Down Expand Up @@ -317,7 +319,7 @@ def free_finished_seq_groups(self) -> None:

def _allocate(self, seq_group: SequenceGroup) -> None:
self.block_manager.allocate(seq_group)
for seq in seq_group.get_seqs():
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
seq.status = SequenceStatus.RUNNING

def _append_slot(
Expand Down

0 comments on commit a1b9cb2

Please sign in to comment.