[Frontend] generation_config.json for maximum tokens (vllm-project#12242)
Signed-off-by: Matthew Hendrey <[email protected]>
Signed-off-by: Shangming Cai <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: Yuan Tang <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Chen Zhang <[email protected]>
Signed-off-by: wangxiyuan <[email protected]>
Co-authored-by: shangmingc <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Co-authored-by: Harry Mellor <[email protected]>
Co-authored-by: Yuan Tang <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: Chen Zhang <[email protected]>
Co-authored-by: wangxiyuan <[email protected]>
9 people authored Jan 26, 2025
1 parent a525527 commit 9ddc352
Showing 4 changed files with 145 additions and 9 deletions.
110 changes: 110 additions & 0 deletions tests/entrypoints/openai/test_serving_chat.py
@@ -103,6 +103,116 @@ def test_serving_chat_should_set_correct_max_tokens():

assert mock_engine.generate.call_args.args[1].max_tokens == 10

# Setting server's max_tokens in the generation_config.json
# lower than context_window - prompt_tokens
mock_model_config = MockModelConfig()
mock_model_config.diff_sampling_param = {
"max_tokens": 10 # Setting server-side max_tokens limit
}

# Reinitialize the engine with new settings
mock_engine = MagicMock(spec=MQLLMEngineClient)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False

# Initialize the serving chat
models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config)
serving_chat = OpenAIServingChat(mock_engine,
mock_model_config,
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)

# Test Case 1: No max_tokens specified in request
req = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{
"role": "user",
"content": "what is 1+1?"
}],
guided_decoding_backend="outlines",
)

with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))

assert mock_engine.generate.call_args.args[1].max_tokens == 10

# Test Case 2: Request's max_tokens set higher than server accepts
req.max_tokens = 15

with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))

assert mock_engine.generate.call_args.args[1].max_tokens == 10

# Test Case 3: Request's max_tokens set lower than server accepts
req.max_tokens = 5

with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))

assert mock_engine.generate.call_args.args[1].max_tokens == 5

# Setting server's max_tokens in the generation_config.json
# higher than context_window - prompt_tokens
mock_model_config = MockModelConfig()
mock_model_config.diff_sampling_param = {
"max_tokens": 200 # Setting server-side max_tokens limit
}

# Reinitialize the engine with new settings
mock_engine = MagicMock(spec=MQLLMEngineClient)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False

# Initialize the serving chat
models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config)
serving_chat = OpenAIServingChat(mock_engine,
mock_model_config,
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)

# Test case 1: No max_tokens specified, defaults to context_window
req = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{
"role": "user",
"content": "what is 1+1?"
}],
guided_decoding_backend="outlines",
)

with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))

assert mock_engine.generate.call_args.args[1].max_tokens == 93

# Test Case 2: Request's max_tokens set higher than server accepts
req.max_tokens = 100

with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))

assert mock_engine.generate.call_args.args[1].max_tokens == 93

# Test Case 3: Request's max_tokens set lower than server accepts
req.max_tokens = 5

with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))

assert mock_engine.generate.call_args.args[1].max_tokens == 5


def test_serving_chat_could_load_correct_generation_config():

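The test above exercises the new server-side cap end to end: the max_tokens handed to the engine is the smallest of the request value, the generation_config.json limit, and the headroom left in the context window (the 93 in the second block is presumably the mock model's max_model_len minus the prompt's token count). A minimal standalone sketch of that selection rule, not taken from vLLM, reproducing the asserted values:

# Hypothetical helper, not part of vLLM: effective max_tokens is the smallest
# non-None value among context-window headroom, request value, and server limit.
def effective_max_tokens(context_headroom, request_max, server_limit):
    return min(v for v in (context_headroom, request_max, server_limit)
               if v is not None)

# Mirrors the assertions in the test above.
assert effective_max_tokens(93, None, 10) == 10   # no request value -> server cap applies
assert effective_max_tokens(93, 15, 10) == 10     # request above the cap -> clamped
assert effective_max_tokens(93, 5, 10) == 5       # request below the cap -> honored
assert effective_max_tokens(93, None, 200) == 93  # cap above headroom -> headroom wins
assert effective_max_tokens(93, 100, 200) == 93
assert effective_max_tokens(93, 5, 200) == 5
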
6 changes: 6 additions & 0 deletions vllm/config.py
@@ -910,12 +910,18 @@ def get_diff_sampling_param(self) -> Dict[str, Any]:
"top_k",
"top_p",
"min_p",
"max_new_tokens",
]
if any(p in config for p in available_params):
diff_sampling_param = {
p: config.get(p)
for p in available_params if config.get(p) is not None
}
# Huggingface definition of max_new_tokens is equivalent
# to vLLM's max_tokens
if "max_new_tokens" in diff_sampling_param:
diff_sampling_param["max_tokens"] = diff_sampling_param.pop(
"max_new_tokens")
else:
diff_sampling_param = {}
return diff_sampling_param
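
The change above teaches get_diff_sampling_param() to pick up max_new_tokens from a HuggingFace-style generation_config.json and rename it to vLLM's max_tokens. A standalone sketch of that translation step (the config contents are hypothetical, not from the commit):

# Hypothetical generation_config.json contents, loaded as a dict.
config = {"temperature": 0.7, "max_new_tokens": 256}

available_params = [
    "repetition_penalty", "temperature", "top_k", "top_p", "min_p",
    "max_new_tokens",
]
diff_sampling_param = {
    p: config.get(p) for p in available_params if config.get(p) is not None
}

# HuggingFace's max_new_tokens plays the role of vLLM's max_tokens.
if "max_new_tokens" in diff_sampling_param:
    diff_sampling_param["max_tokens"] = diff_sampling_param.pop("max_new_tokens")

print(diff_sampling_param)  # {'temperature': 0.7, 'max_tokens': 256}
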
4 changes: 3 additions & 1 deletion vllm/engine/arg_utils.py
@@ -939,7 +939,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"Defaults to None, will use the default generation config in vLLM. "
"If set to 'auto', the generation config will be automatically "
"loaded from model. If set to a folder path, the generation config "
"will be loaded from the specified folder path.")
"will be loaded from the specified folder path. If "
"`max_new_tokens` is specified, then it sets a server-wide limit "
"on the number of output tokens for all requests.")

parser.add_argument("--enable-sleep-mode",
action="store_true",
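
To make the expanded help text concrete: if the folder passed to this option contains a generation_config.json that sets max_new_tokens, that value becomes a server-wide cap on output tokens for every request. A sketch of setting that up (paths, model name, and values are illustrative only, not from the commit):

# Sketch: create a generation config folder whose max_new_tokens acts as a
# server-wide output-token cap. Paths and values are illustrative.
import json
import pathlib

cfg_dir = pathlib.Path("./my_generation_config")
cfg_dir.mkdir(exist_ok=True)
(cfg_dir / "generation_config.json").write_text(
    json.dumps({"temperature": 0.7, "max_new_tokens": 256}))

# Then point the server at that folder, e.g. (assuming the standard vllm CLI):
#   vllm serve <model> --generation-config ./my_generation_config
# Requests may ask for fewer tokens, but never receive more than 256.
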
34 changes: 26 additions & 8 deletions vllm/entrypoints/openai/protocol.py
@@ -380,13 +380,17 @@ def to_beam_search_params(
) -> BeamSearchParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = self.max_completion_tokens or self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens

if default_sampling_params is None:
default_sampling_params = {}
n = self.n if self.n is not None else 1

# Use minimum of context window, user request & server limit.
max_tokens = min(
val for val in (default_max_tokens, max_tokens,
default_sampling_params.get("max_tokens", None))
if val is not None)

if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
@@ -406,11 +410,16 @@ def to_sampling_params(
default_sampling_params: Optional[dict] = None) -> SamplingParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = self.max_completion_tokens or self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens

if default_sampling_params is None:
default_sampling_params = {}

# Use minimum of context window, user request & server limit.
max_tokens = min(
val for val in (default_max_tokens, max_tokens,
default_sampling_params.get("max_tokens", None))
if val is not None)

# Default parameters
if (repetition_penalty := self.repetition_penalty) is None:
repetition_penalty = default_sampling_params.get(
@@ -740,13 +749,17 @@ def to_beam_search_params(
default_sampling_params: Optional[dict] = None
) -> BeamSearchParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens

if default_sampling_params is None:
default_sampling_params = {}
n = self.n if self.n is not None else 1

# Use minimum of context window, user request & server limit.
max_tokens = min(
val for val in (default_max_tokens, max_tokens,
default_sampling_params.get("max_tokens", None))
if val is not None)

if (temperature := self.temperature) is None:
temperature = default_sampling_params.get("temperature", 1.0)

@@ -764,11 +777,16 @@ def to_sampling_params(
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens

if default_sampling_params is None:
default_sampling_params = {}

# Use minimum of context window, user request & server limit.
max_tokens = min(
val for val in (default_max_tokens, max_tokens,
default_sampling_params.get("max_tokens", None))
if val is not None)

# Default parameters
if (repetition_penalty := self.repetition_penalty) is None:
repetition_penalty = default_sampling_params.get(
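
Across all four request-conversion paths (chat and legacy completions, sampling and beam search), the old behavior at this step was to substitute default_max_tokens only when the request omitted max_tokens; a request-supplied value was passed through unchanged. The new rule takes the tightest of every limit that is actually set. A standalone comparison, with hypothetical values:

# Hypothetical sketch contrasting the old fallback with the new clamp.
def old_rule(default_max_tokens, request_max_tokens):
    # Before: only fill in a default when the request omits max_tokens.
    return (request_max_tokens
            if request_max_tokens is not None else default_max_tokens)

def new_rule(default_max_tokens, request_max_tokens, server_limit=None):
    # After: minimum over context headroom, request value, and server limit.
    return min(val for val in (default_max_tokens, request_max_tokens, server_limit)
               if val is not None)

print(old_rule(93, 500))        # 500 -- request value used as-is at this step
print(new_rule(93, 500))        # 93  -- clamped to the context-window headroom
print(new_rule(93, None, 200))  # 93
print(new_rule(93, 50, 200))    # 50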
