Commit 5069fb7: Address comment

zifeitong committed Jul 31, 2024
1 parent 68ada2d
Showing 3 changed files with 24 additions and 16 deletions.
vllm/entrypoints/openai/protocol.py (16 additions & 8 deletions)

@@ -216,9 +216,13 @@ class ChatCompletionRequest(OpenAIBaseModel):
     # doc: end-chat-completion-extra-params
 
     def to_sampling_params(
-            self, tokenizer: PreTrainedTokenizer,
-            guided_decode_logits_processor: Optional[LogitsProcessor]
-    ) -> SamplingParams:
+            self, tokenizer: PreTrainedTokenizer,
+            guided_decode_logits_processor: Optional[LogitsProcessor],
+            default_max_tokens: int) -> SamplingParams:
+        max_tokens = self.max_tokens
+        if max_tokens is None:
+            max_tokens = default_max_tokens
+
         # We now allow logprobs being true without top_logrobs.
         logits_processors = get_logits_processors(
             logit_bias=self.logit_bias,
@@ -244,7 +248,7 @@ def to_sampling_params(
             logprobs=self.top_logprobs if self.logprobs else None,
             prompt_logprobs=self.top_logprobs if self.echo else None,
             ignore_eos=self.ignore_eos,
-            max_tokens=self.max_tokens,
+            max_tokens=max_tokens,
             min_tokens=self.min_tokens,
             use_beam_search=self.use_beam_search,
             early_stopping=self.early_stopping,
@@ -399,9 +403,13 @@ class CompletionRequest(OpenAIBaseModel):
     # doc: end-completion-extra-params
 
     def to_sampling_params(
-            self, tokenizer: PreTrainedTokenizer,
-            guided_decode_logits_processor: Optional[LogitsProcessor]
-    ) -> SamplingParams:
+            self, tokenizer: PreTrainedTokenizer,
+            guided_decode_logits_processor: Optional[LogitsProcessor],
+            default_max_tokens: int) -> SamplingParams:
+        max_tokens = self.max_tokens
+        if max_tokens is None:
+            max_tokens = default_max_tokens
+
         echo_without_generation = self.echo and self.max_tokens == 0
 
         logits_processors = get_logits_processors(
@@ -427,7 +435,7 @@ def to_sampling_params(
             stop_token_ids=self.stop_token_ids,
             logprobs=self.logprobs,
             ignore_eos=self.ignore_eos,
-            max_tokens=self.max_tokens if not echo_without_generation else 1,
+            max_tokens=max_tokens if not echo_without_generation else 1,
             min_tokens=self.min_tokens,
             use_beam_search=self.use_beam_search,
             early_stopping=self.early_stopping,
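The defaulting pattern introduced above can be read in isolation as roughly the following. This is a minimal sketch for illustration only, not the vLLM code itself; SimpleRequest and resolve_max_tokens are hypothetical names standing in for the request classes and their to_sampling_params logic.

# Minimal sketch of the new defaulting behaviour (hypothetical names).
from typing import Optional


class SimpleRequest:
    def __init__(self, max_tokens: Optional[int] = None):
        self.max_tokens = max_tokens

    def resolve_max_tokens(self, default_max_tokens: int) -> int:
        # Mirror the added logic: keep a client-supplied value if present,
        # otherwise fall back to the server-computed default.
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens
        return max_tokens


# A request without max_tokens picks up the default (e.g. remaining context).
assert SimpleRequest().resolve_max_tokens(default_max_tokens=512) == 512
# An explicit client value is preserved unchanged.
assert SimpleRequest(max_tokens=64).resolve_max_tokens(default_max_tokens=512) == 64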
vllm/entrypoints/openai/serving_chat.py (4 additions & 4 deletions)

@@ -144,10 +144,10 @@ async def create_chat_completion(
             )
 
             sampling_params = request.to_sampling_params(
-                tokenizer, guided_decode_logits_processor)
-            if sampling_params.max_tokens is None:
-                sampling_params.max_tokens = \
-                    self.max_model_len - len(prompt_inputs["prompt_token_ids"])
+                tokenizer,
+                guided_decode_logits_processor,
+                default_max_tokens=self.max_model_len -
+                len(prompt_inputs["prompt_token_ids"]))
 
             self._log_inputs(request_id,
                              prompt_inputs,
vllm/entrypoints/openai/serving_completion.py (4 additions & 4 deletions)

@@ -106,10 +106,10 @@ async def create_completion(self, request: CompletionRequest,
 
         for i, prompt_inputs in enumerate(prompts):
             sampling_params = request.to_sampling_params(
-                tokenizer, guided_decode_logits_processor)
-            if sampling_params.max_tokens is None:
-                sampling_params.max_tokens = self.max_model_len - \
-                    len(prompt_inputs["prompt_token_ids"])
+                tokenizer,
+                guided_decode_logits_processor,
+                default_max_tokens=self.max_model_len -
+                len(prompt_inputs["prompt_token_ids"]))
 
             request_id_item = f"{request_id}-{i}"
 
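On the caller side, both servers now compute the default token budget up front and pass it in, instead of patching sampling_params.max_tokens after the fact. A standalone sketch of that arithmetic, assuming an example context window and prompt length (the concrete numbers are illustrative, not taken from the diff):

# Hypothetical standalone illustration of the caller-side computation.
max_model_len = 4096                  # model context window (example value)
prompt_token_ids = list(range(1000))  # pretend the prompt tokenized to 1000 ids

default_max_tokens = max_model_len - len(prompt_token_ids)
assert default_max_tokens == 3096

# This budget is then forwarded as in the diff above:
#   request.to_sampling_params(tokenizer,
#                              guided_decode_logits_processor,
#                              default_max_tokens=default_max_tokens)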
