[Frontend] Add logits_processors as an extra completion argument (#11150)

Signed-off-by: Brad Hilton <[email protected]>
bradhilton authored Dec 14, 2024
1 parent 3cb5769 commit 9c3dadd
Showing 6 changed files with 127 additions and 39 deletions.
1 change: 1 addition & 0 deletions tests/entrypoints/openai/test_serving_chat.py
@@ -30,6 +30,7 @@ class MockModelConfig:
tokenizer_revision = None
multimodal_config = MultiModalConfig()
hf_config = MockHFConfig()
logits_processor_pattern = None


@dataclass
71 changes: 38 additions & 33 deletions vllm/config.py
@@ -156,41 +156,45 @@ class ModelConfig:
can not be gathered from the vllm arguments.
override_pooler_config: Initialize non default pooling config or
override default pooling config for the pooling model.
logits_processor_pattern: Optional regex pattern specifying valid
logits processor qualified names that can be passed with the
`logits_processors` extra completion argument. Defaults to None,
which allows no processors.
"""

def __init__(
self,
model: str,
task: Union[TaskOption, Literal["draft"]],
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
dtype: Union[str, torch.dtype],
seed: int,
allowed_local_media_path: str = "",
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_theta: Optional[float] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
spec_target_max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
quantization_param_path: Optional[str] = None,
enforce_eager: Optional[bool] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 20,
disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
use_async_output_proc: bool = True,
config_format: ConfigFormat = ConfigFormat.AUTO,
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
mm_cache_preprocessor: bool = False,
override_neuron_config: Optional[Dict[str, Any]] = None,
override_pooler_config: Optional["PoolerConfig"] = None) -> None:
def __init__(self,
model: str,
task: Union[TaskOption, Literal["draft"]],
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
dtype: Union[str, torch.dtype],
seed: int,
allowed_local_media_path: str = "",
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_theta: Optional[float] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
spec_target_max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
quantization_param_path: Optional[str] = None,
enforce_eager: Optional[bool] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 20,
disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
use_async_output_proc: bool = True,
config_format: ConfigFormat = ConfigFormat.AUTO,
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
mm_cache_preprocessor: bool = False,
override_neuron_config: Optional[Dict[str, Any]] = None,
override_pooler_config: Optional["PoolerConfig"] = None,
logits_processor_pattern: Optional[str] = None) -> None:
self.model = model
self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode
@@ -316,6 +320,7 @@ def __init__(
self.task: Final = task

self.pooler_config = self._init_pooler_config(override_pooler_config)
self.logits_processor_pattern = logits_processor_pattern

self._verify_quantization()
self._verify_cuda_graph()
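
The stored pattern is later checked with re.match when a request names a logits processor (see the get_logits_processors helper added in vllm/entrypoints/openai/protocol.py below), so it is anchored at the start of the qualified name but not at the end. A brief sketch of what a pattern would and would not admit; the pattern value and module names here are illustrative assumptions:

import re

pattern = r"my_plugins\.logits\."  # hypothetical server-side pattern

assert re.match(pattern, "my_plugins.logits.BanTokenProcessor")      # allowed
assert re.match(pattern, "my_plugins.logits.helpers.scale_logits")   # allowed
assert re.match(pattern, "some_other_module.Processor") is None      # rejected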
11 changes: 10 additions & 1 deletion vllm/engine/arg_utils.py
@@ -170,6 +170,7 @@ class EngineArgs:
enable_chunked_prefill: Optional[bool] = None

guided_decoding_backend: str = 'xgrammar'
logits_processor_pattern: Optional[str] = None
# Speculative decoding configuration.
speculative_model: Optional[str] = None
speculative_model_quantization: Optional[str] = None
@@ -374,6 +375,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'https://github.com/noamgat/lm-format-enforcer.'
' Can be overridden per request via guided_decoding_backend'
' parameter.')
parser.add_argument(
'--logits-processor-pattern',
type=nullable_str,
default=None,
help='Optional regex pattern specifying valid logits processor '
'qualified names that can be passed with the `logits_processors` '
'extra completion argument. Defaults to None, which allows no '
'processors.')
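# Editorial sketch, not part of this diff: a hedged example of how the new
# flag would round-trip from the CLI into EngineArgs (the pattern value is
# an assumption):
#
#     from vllm.engine.arg_utils import EngineArgs
#     from vllm.utils import FlexibleArgumentParser
#
#     parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
#     cli_args = parser.parse_args([
#         "--model", "facebook/opt-125m",
#         "--logits-processor-pattern", r"my_plugins\.",
#     ])
#     engine_args = EngineArgs.from_cli_args(cli_args)
#     assert engine_args.logits_processor_pattern == r"my_plugins\."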
# Parallel arguments
parser.add_argument(
'--distributed-executor-backend',
@@ -975,7 +984,7 @@ def create_model_config(self) -> ModelConfig:
mm_cache_preprocessor=self.mm_cache_preprocessor,
override_neuron_config=self.override_neuron_config,
override_pooler_config=self.override_pooler_config,
)
logits_processor_pattern=self.logits_processor_pattern)

def create_load_config(self) -> LoadConfig:
return LoadConfig(
77 changes: 74 additions & 3 deletions vllm/entrypoints/openai/protocol.py
@@ -1,5 +1,6 @@
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import re
import time
from argparse import Namespace
from typing import Any, Dict, List, Literal, Optional, Union
@@ -14,7 +15,7 @@
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
RequestOutputKind, SamplingParams)
from vllm.sequence import Logprob
from vllm.utils import random_uuid
from vllm.utils import random_uuid, resolve_obj_by_qualname

logger = init_logger(__name__)

@@ -148,6 +149,46 @@ class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
type: Literal["function"] = "function"


class LogitsProcessorConstructor(BaseModel):
qualname: str
args: Optional[List[Any]] = None
kwargs: Optional[Dict[str, Any]] = None


LogitsProcessors = List[Union[str, LogitsProcessorConstructor]]


def get_logits_processors(processors: Optional[LogitsProcessors],
pattern: Optional[str]) -> Optional[List[Any]]:
if processors and pattern:
logits_processors = []
for processor in processors:
qualname = processor if isinstance(processor,
str) else processor.qualname
if not re.match(pattern, qualname):
raise ValueError(
f"Logits processor '{qualname}' is not allowed by this "
"server. See --logits-processor-pattern engine argument "
"for more information.")
try:
logits_processor = resolve_obj_by_qualname(qualname)
except Exception as e:
raise ValueError(
f"Logits processor '{qualname}' could not be resolved: {e}"
) from e
if isinstance(processor, LogitsProcessorConstructor):
logits_processor = logits_processor(*processor.args or [],
**processor.kwargs or {})
logits_processors.append(logits_processor)
return logits_processors
elif processors:
raise ValueError(
"The `logits_processors` argument is not supported by this "
"server. See --logits-processor-pattern engine argugment "
"for more information.")
return None
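
# Editorial note, not part of the diff: a plain string entry is resolved via
# resolve_obj_by_qualname and used as the processor directly, so it should
# name something that is already callable with the processor signature; a
# LogitsProcessorConstructor is resolved the same way and then called with
# its args/kwargs. A hedged sketch of a module that would satisfy either
# form (the package "my_plugins" and everything in it are assumptions):
#
# my_plugins/__init__.py
from typing import List

import torch


def ban_eos(token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
    """Usable via the plain-string form: "my_plugins.ban_eos".

    Assumes token id 2 is the EOS token of the served model.
    """
    logits[2] = float("-inf")
    return logits


class BanTokenProcessor:
    """Usable via the constructor form:
    {"qualname": "my_plugins.BanTokenProcessor", "args": [42]}."""

    def __init__(self, banned_token_id: int):
        self.banned_token_id = banned_token_id

    def __call__(self, token_ids: List[int],
                 logits: torch.Tensor) -> torch.Tensor:
        logits[self.banned_token_id] = float("-inf")
        return logits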


class ChatCompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
@@ -293,6 +334,17 @@ class ChatCompletionRequest(OpenAIBaseModel):
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."))
logits_processors: Optional[LogitsProcessors] = Field(
default=None,
description=(
"A list of either qualified names of logits processors, or "
"constructor objects, to apply when sampling. A constructor is "
"a JSON object with a required 'qualname' field specifying the "
"qualified name of the processor class/factory, and optional "
"'args' and 'kwargs' fields containing positional and keyword "
"arguments. For example: {'qualname': "
"'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
"{'param': 'value'}}."))

# doc: end-chat-completion-extra-params
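
As a usage sketch, the new extra argument can be passed through the OpenAI Python client's extra_body; the server address, model name, and processor qualname below are assumptions, not part of this change:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
chat = client.chat.completions.create(
    model="facebook/opt-125m",
    messages=[{"role": "user", "content": "Say hello."}],
    extra_body={
        "logits_processors": [
            {"qualname": "my_plugins.BanTokenProcessor", "args": [42]},
        ],
    },
)
print(chat.choices[0].message.content)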

@@ -314,7 +366,9 @@ def to_beam_search_params(self,
length_penalty=self.length_penalty,
include_stop_str_in_output=self.include_stop_str_in_output)

def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
def to_sampling_params(
self, default_max_tokens: int,
logits_processor_pattern: Optional[str]) -> SamplingParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = self.max_completion_tokens or self.max_tokens
if max_tokens is None:
@@ -364,6 +418,8 @@ def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
min_tokens=self.min_tokens,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
logits_processors=get_logits_processors(self.logits_processors,
logits_processor_pattern),
include_stop_str_in_output=self.include_stop_str_in_output,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
@@ -599,6 +655,17 @@ class CompletionRequest(OpenAIBaseModel):
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
logits_processors: Optional[LogitsProcessors] = Field(
default=None,
description=(
"A list of either qualified names of logits processors, or "
"constructor objects, to apply when sampling. A constructor is "
"a JSON object with a required 'qualname' field specifying the "
"qualified name of the processor class/factory, and optional "
"'args' and 'kwargs' fields containing positional and keyword "
"arguments. For example: {'qualname': "
"'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
"{'param': 'value'}}."))

# doc: end-completion-extra-params
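
The pattern check happens when the request is converted to SamplingParams, so a qualname that does not match the server's --logits-processor-pattern is rejected before any import is attempted. An illustrative sketch; the pattern and module names are assumptions:

from vllm.entrypoints.openai.protocol import CompletionRequest

request = CompletionRequest(
    model="facebook/opt-125m",
    prompt="Hello",
    logits_processors=["some_other_module.Processor"],
)
try:
    request.to_sampling_params(
        default_max_tokens=16,
        logits_processor_pattern=r"my_plugins\.",
    )
except ValueError as exc:
    print(exc)  # "... is not allowed by this server ..."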

@@ -619,7 +686,9 @@ def to_beam_search_params(self,
length_penalty=self.length_penalty,
include_stop_str_in_output=self.include_stop_str_in_output)

def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
def to_sampling_params(
self, default_max_tokens: int,
logits_processor_pattern: Optional[str]) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
@@ -665,6 +734,8 @@ def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
logits_processors=get_logits_processors(self.logits_processors,
logits_processor_pattern),
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
3 changes: 2 additions & 1 deletion vllm/entrypoints/openai/serving_chat.py
@@ -197,7 +197,8 @@ async def create_chat_completion(
default_max_tokens)
else:
sampling_params = request.to_sampling_params(
default_max_tokens)
default_max_tokens,
self.model_config.logits_processor_pattern)

self._log_inputs(request_id,
request_prompts[i],
3 changes: 2 additions & 1 deletion vllm/entrypoints/openai/serving_completion.py
@@ -123,7 +123,8 @@ async def create_completion(
default_max_tokens)
else:
sampling_params = request.to_sampling_params(
default_max_tokens)
default_max_tokens,
self.model_config.logits_processor_pattern)

request_id_item = f"{request_id}-{i}"

