[Performance] Optimize get_seqs (#7051)
WoosukKwon authored Aug 2, 2024
1 parent 6a11fdf commit 6ce01f3
Showing 3 changed files with 22 additions and 22 deletions.
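Summary of the change: SequenceGroup previously stored its sequences only in seqs_dict, and get_seqs() rebuilt a Python list from the dict's values on every call, even though that method sits on the scheduler's hot path. This commit adds a parallel self.seqs list, kept in sync by add() and remove(), so the unfiltered get_seqs() call and the internal loops that used to iterate seqs_dict.values() can use the list directly. Below is a minimal sketch of that pattern; the Seq and SeqGroupSketch names, fields, and the string status are simplified stand-ins for illustration, not the real vLLM types.

# Minimal sketch of the list-plus-dict pattern this commit applies.
# Simplified stand-in types, not the actual vLLM classes.
from typing import Dict, List, Optional


class Seq:
    def __init__(self, seq_id: int, status: str = "RUNNING") -> None:
        self.seq_id = seq_id
        self.status = status


class SeqGroupSketch:
    def __init__(self, seqs: List[Seq]) -> None:
        # The list is the primary container for iteration and indexing...
        self.seqs = seqs
        # ...while the dict is kept for O(1) lookup by seq_id.
        self.seqs_dict: Dict[int, Seq] = {s.seq_id: s for s in seqs}

    def get_seqs(self, status: Optional[str] = None) -> List[Seq]:
        # Hot path: no allocation and no copy when no filter is requested.
        if status is None:
            return self.seqs
        return [s for s in self.seqs if s.status == status]

    def add(self, seq: Seq) -> None:
        if seq.seq_id in self.seqs_dict:
            raise ValueError(f"Sequence {seq.seq_id} already exists.")
        self.seqs_dict[seq.seq_id] = seq
        self.seqs.append(seq)  # keep both containers in sync

    def remove(self, seq_id: int) -> None:
        seq = self.seqs_dict.pop(seq_id, None)
        if seq is None:
            raise ValueError(f"Sequence {seq_id} not found.")
        self.seqs.remove(seq)  # O(n) list removal; groups typically hold few sequences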
vllm/core/block_manager_v1.py (1 addition, 1 deletion)

@@ -700,5 +700,5 @@ def get_common_computed_block_ids(

     def mark_blocks_as_computed(self, seq_group: SequenceGroup):
         if self.enable_caching:
-            for seq in seq_group.seqs_dict.values():
+            for seq in seq_group.get_seqs():
                 self.compute_full_blocks_in_seq(seq)
vllm/sequence.py (20 additions, 20 deletions)

@@ -444,6 +444,7 @@ def __init__(
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
         self.request_id = request_id
+        self.seqs = seqs
         self.seqs_dict = {seq.seq_id: seq for seq in seqs}
         self.sampling_params = sampling_params
         self.metrics = RequestMetrics(arrival_time=arrival_time,
@@ -458,25 +459,24 @@ def __init__(
         self.prompt_adapter_request = prompt_adapter_request
         self.encoder_seq = encoder_seq
         self.trace_headers = trace_headers
-        self._first_seq = next(iter(self.seqs_dict.values()))

     @property
     def prompt(self) -> Optional[str]:
         # All sequences in the group should have the same prompt.
         # We use the prompt of an arbitrary sequence.
-        return self._first_seq.prompt
+        return self.seqs[0].prompt

     @property
     def prompt_token_ids(self) -> List[int]:
         # All sequences in the group should have the same prompt.
         # We use the prompt of an arbitrary sequence.
-        return self._first_seq.prompt_token_ids
+        return self.seqs[0].prompt_token_ids

     @property
     def multi_modal_data(self) -> "MultiModalDataDict":
         # All sequences in the group should have the same multi-modal data.
         # We use the multi-modal data of an arbitrary sequence.
-        return self._first_seq.multi_modal_data
+        return self.seqs[0].multi_modal_data

     @property
     def lora_int_id(self) -> int:
@@ -512,7 +512,7 @@ def maybe_set_first_token_time(self, time: float) -> None:
         # in TPOT, rather than recalculating TTFT (since from the )
         # POV of the user, there is simply a long generation delay.
         if (self.metrics.first_token_time is None
-                and self.get_seqs()[0].get_output_len() == 1):
+                and self.seqs[0].get_output_len() == 1):
             self.metrics.first_token_time = time

     def maybe_set_first_scheduled_time(self, time: float) -> None:
@@ -548,9 +548,9 @@ def get_seqs(
         self,
         status: Optional[SequenceStatus] = None,
     ) -> List[Sequence]:
-        return list(self.seqs_dict.values()) if status is None else [
-            seq for seq in self.seqs_dict.values() if seq.status == status
-        ]
+        if status is None:
+            return self.seqs
+        return [seq for seq in self.seqs if seq.status == status]

     def is_encoder_decoder(self) -> bool:
         return self.encoder_seq is not None
@@ -559,22 +559,20 @@ def get_encoder_seq(self) -> Optional[Sequence]:
         return self.encoder_seq

     def get_unfinished_seqs(self) -> List[Sequence]:
-        return [
-            seq for seq in self.seqs_dict.values() if not seq.is_finished()
-        ]
+        return [seq for seq in self.seqs if not seq.is_finished()]

     def get_finished_seqs(self) -> List[Sequence]:
-        return [seq for seq in self.seqs_dict.values() if seq.is_finished()]
+        return [seq for seq in self.seqs if seq.is_finished()]

     def update_num_computed_tokens(self, num_new_computed_tokens: int):
         """Update number of tokens computed so far."""
-        for seq in self.seqs_dict.values():
+        for seq in self.seqs:
             if not seq.is_finished():
                 seq.data.update_num_computed_tokens(num_new_computed_tokens)

     def get_num_uncomputed_tokens(self) -> int:
         num_uncomputed_tokens = 0
-        for seq in self.get_seqs():
+        for seq in self.seqs:
             if not seq.is_finished():
                 num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
         return num_uncomputed_tokens
@@ -583,7 +581,7 @@ def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
         # Optimization. We don't need to call get_seqs if we don't need to
         # filter by states.
         if status is None:
-            return len(self.seqs_dict)
+            return len(self.seqs)

         return len(self.get_seqs(status))

@@ -602,23 +600,25 @@ def add(self, seq: Sequence) -> None:
         if seq.seq_id in self.seqs_dict:
             raise ValueError(f"Sequence {seq.seq_id} already exists.")
         self.seqs_dict[seq.seq_id] = seq
+        self.seqs.append(seq)

     def remove(self, seq_id: int) -> None:
-        if seq_id not in self.seqs_dict:
+        seq = self.seqs_dict.pop(seq_id, None)
+        if seq is None:
             raise ValueError(f"Sequence {seq_id} not found.")
-        del self.seqs_dict[seq_id]
+        self.seqs.remove(seq)

     def is_finished(self) -> bool:
-        return all(seq.is_finished() for seq in self.get_seqs())
+        return all(seq.is_finished() for seq in self.seqs)

     def is_prefill(self) -> bool:
         # Every sequence should be in the same stage.
-        return self.get_seqs()[0].is_prefill()
+        return self.seqs[0].is_prefill()

     def __repr__(self) -> str:
         return (f"SequenceGroup(request_id={self.request_id}, "
                 f"sampling_params={self.sampling_params}, "
-                f"num_seqs={len(self.seqs_dict)})")
+                f"num_seqs={len(self.seqs)})")


 class SequenceGroupMetadata:
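The saving in sequence.py is the avoided list construction on a per-step hot path. The rough timing sketch below contrasts materializing a list from dict values with returning an already-stored list; it is a generic Python illustration, not a vLLM benchmark, and the numbers will vary by machine and group size.

# Rough illustration of the avoided cost, assuming a group of 4 sequences and
# many get_seqs() calls per scheduler step. Not a vLLM benchmark; numbers vary.
import timeit

seqs = [object() for _ in range(4)]
seqs_dict = {i: s for i, s in enumerate(seqs)}

old = timeit.timeit(lambda: list(seqs_dict.values()), number=1_000_000)
new = timeit.timeit(lambda: seqs, number=1_000_000)

print(f"list(dict.values()) total: {old:.3f}s for 1e6 calls")
print(f"stored list total:         {new:.3f}s for 1e6 calls")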
vllm/transformers_utils/detokenizer.py (1 addition, 1 deletion)

@@ -40,7 +40,7 @@ def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
         assert prms is not None

         # We can pick any sequence for the prompt.
-        seq = next(iter(seq_group.seqs_dict.values()))
+        seq = seq_group.get_seqs()[0]
         # Only prompt, without the generated token.
         all_token_ids = seq.get_token_ids()
         prompt_token_ids = all_token_ids[:-1]
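The detokenizer call site keeps its behavior: next(iter(seq_group.seqs_dict.values())) and seq_group.get_seqs()[0] both yield the first-inserted sequence, because Python dicts preserve insertion order and the new list mirrors the dict. A tiny generic check of that equivalence, using plain strings rather than vLLM sequence objects:

# Sanity check of the equivalence the call-site change relies on (simplified).
d = {"seq-0": "first", "seq-1": "second"}
lst = list(d.values())
assert next(iter(d.values())) == lst[0] == "first"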
