
Commit 5dd4caa

optimize
Signed-off-by: Woosuk Kwon <[email protected]>
WoosukKwon committed Dec 21, 2024
1 parent 0912e3e commit 5dd4caa
Showing 1 changed file with 16 additions and 25 deletions.
41 changes: 16 additions & 25 deletions vllm/v1/worker/gpu_input_batch.py
@@ -322,7 +322,7 @@ def make_sampling_metadata(
         req_id_output_token_ids: Dict[str, List[int]],
         skip_copy: bool = False,
     ) -> SamplingMetadata:
-        prompt_tokens_tensor: Optional[torch.Tensor] = None
+        prompt_token_ids: Optional[torch.Tensor] = None
         if not skip_copy:
             self.temperature[:self.num_reqs].copy_(
                 self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True)
@@ -346,8 +346,7 @@ def make_sampling_metadata(
             # The prompt tokens are used only for applying penalties during
             # the sampling process. Hence copy these tensors only when
             # there are requests which need penalties to be applied.
-            prompt_tokens_tensor = self._construct_prompt_tokens_tensor(
-                self.vocab_size, device=self.device)
+            prompt_token_ids = self._make_prompt_token_ids_tensor()
 
         output_token_ids: List[List[int]] = []
 
@@ -372,7 +371,7 @@ def make_sampling_metadata(
             no_top_k=self.no_top_k,
             generators=self.generators,
             max_num_logprobs=self.max_num_logprobs,
-            prompt_token_ids=prompt_tokens_tensor,
+            prompt_token_ids=prompt_token_ids,
             frequency_penalties=self.frequency_penalties[:self.num_reqs],
             presence_penalties=self.presence_penalties[:self.num_reqs],
             repetition_penalties=self.repetition_penalties[:self.num_reqs],
@@ -382,30 +381,22 @@
             no_penalties=self.no_penalties,
         )
 
-    def _construct_prompt_tokens_tensor(
-        self,
-        vocab_size: int,
-        device: torch.device,
-    ) -> torch.Tensor:
+    def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
         max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
-        # use the value of vocab_size as a pad since we don't have a
+        prompt_token_ids_cpu_tensor = torch.empty(
+            (self.num_reqs, max_prompt_len),
+            device="cpu",
+            dtype=torch.int64,
+            pin_memory=self.pin_memory)
+        prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
+        prompt_token_ids[:] = (
+            self.token_ids_cpu[:self.num_reqs, :max_prompt_len])
+        # Use the value of vocab_size as a pad since we don't have a
         # token_id of this value.
-        # TODO - Add a method in vllm/utils.py to pad a numpy array similar
-        # to make_tensor_with_pad which takes a list and move the logic
-        # there.
-        padded_prompts = np.full((self.num_reqs, max_prompt_len),
-                                 vocab_size,
-                                 dtype=np.int64)
         for i in range(self.num_reqs):
-            padded_prompts[i, :self.num_prompt_tokens[i]] = \
-                self.token_ids_cpu[i, :self.num_prompt_tokens[i]]
-        prompt_tokens_cpu_tensor = torch.from_numpy(padded_prompts)
-        if self.pin_memory:
-            prompt_tokens_cpu_tensor = \
-                prompt_tokens_cpu_tensor.pin_memory()
-        prompt_tokens_tensor = prompt_tokens_cpu_tensor.to(device=device,
-                                                           non_blocking=True)
-        return prompt_tokens_tensor
+            prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
+        return prompt_token_ids_cpu_tensor.to(device=self.device,
+                                              non_blocking=True)
 
     @property
     def num_reqs(self) -> int:
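The rewritten helper above is the whole optimization: instead of filling an np.full array row by row, converting it with torch.from_numpy, pinning it in a second copy, and then moving it to the GPU, the new code allocates one pinned CPU tensor up front, fills it through a NumPy view, pads only each row's tail with vocab_size, and issues a single non-blocking host-to-device copy. The sketch below reproduces that staging pattern outside vLLM; it is a minimal illustration, and the function name make_padded_ids, the toy shapes, and the pad value 32_000 are assumptions for the example, not part of this commit.

import numpy as np
import torch


def make_padded_ids(token_ids_cpu: np.ndarray,
                    lengths: np.ndarray,
                    pad_value: int,
                    device: torch.device) -> torch.Tensor:
    """Stage token ids in a pinned CPU buffer, pad row tails, copy to device."""
    num_rows = lengths.shape[0]
    max_len = int(lengths.max())
    # Allocate the staging buffer once, already pinned (when CUDA is
    # available), so the later .to(device, non_blocking=True) can be a
    # genuinely asynchronous copy.
    staging = torch.empty((num_rows, max_len),
                          dtype=torch.int64,
                          device="cpu",
                          pin_memory=torch.cuda.is_available())
    view = staging.numpy()  # shares memory with the pinned tensor
    # One vectorized copy for all rows, then overwrite only each row's
    # tail with the pad value instead of rebuilding rows one by one.
    view[:] = token_ids_cpu[:num_rows, :max_len]
    for i in range(num_rows):
        view[i, lengths[i]:] = pad_value
    return staging.to(device=device, non_blocking=True)


if __name__ == "__main__":
    # Toy data: 3 requests whose prompts occupy the first lens[i] slots
    # of a shared (num_rows, max_len) token buffer.
    ids = np.arange(12, dtype=np.int64).reshape(3, 4)
    lens = np.array([2, 4, 3])
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(make_padded_ids(ids, lens, pad_value=32_000, device=dev))

Pinned (page-locked) host memory is what lets the final .to(device, non_blocking=True) overlap with other work; with ordinary pageable memory the transfer effectively degrades to a blocking copy, which is why the buffer is created with pin_memory from the start rather than pinned in a separate step afterwards.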
