Skip to content

Commit

Permalink
[Bugfix][Frontend] Guard against bad token ids (#9634)
Browse files Browse the repository at this point in the history
Signed-off-by: Joe Runde <[email protected]>
  • Loading branch information
joerunde authored Oct 29, 2024
1 parent 0ad216f commit 67bdf8e
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 17 deletions.
8 changes: 7 additions & 1 deletion tests/entrypoints/llm/test_prompt_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@


def test_empty_prompt():
llm = LLM(model="gpt2")
llm = LLM(model="gpt2", enforce_eager=True)
with pytest.raises(ValueError, match='Prompt cannot be empty'):
llm.generate([""])


def test_out_of_vocab_token():
llm = LLM(model="gpt2", enforce_eager=True)
with pytest.raises(ValueError, match='out of vocabulary'):
llm.generate({"prompt_token_ids": [999999]})
18 changes: 9 additions & 9 deletions tests/entrypoints/openai/test_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,15 +157,15 @@ async def test_added_lora_tokens(client: openai.AsyncOpenAI):
@pytest.mark.asyncio
async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI):
# test using token IDs
completion = await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 32000, 32001, 32002],
echo=True,
max_tokens=5,
temperature=0.0,
)
# Added tokens should not appear in tokenized prompt
assert "vllm" not in completion.choices[0].text
with pytest.raises(openai.BadRequestError, match="out of vocabulary"):
# Added tokens should be rejected by the base model
await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 32000, 32001, 32002],
echo=True,
max_tokens=5,
temperature=0.0,
)


@pytest.mark.asyncio
Expand Down
15 changes: 15 additions & 0 deletions tests/entrypoints/openai/test_prompt_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,18 @@ async def test_empty_prompt():
prompt="",
max_tokens=5,
temperature=0.0)


@pytest.mark.asyncio
async def test_out_of_vocab_token_ids():
model_name = "gpt2"
server_args = ["--enforce-eager"]
with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()

with pytest.raises(openai.BadRequestError,
match=re.compile('.*out of vocabulary.*')):
await client.completions.create(model=model_name,
prompt=[999999],
max_tokens=5,
temperature=0.0)
15 changes: 12 additions & 3 deletions vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,12 @@ async def stop_remote_worker_execution_loop_async(self) -> None:
"""Stop the remote worker execution loop."""
await self.model_executor.stop_remote_worker_execution_loop_async()

async def get_tokenizer_async(self,
lora_request: Optional[LoRARequest] = None
) -> AnyTokenizer:
return await (
self.get_tokenizer_group().get_lora_tokenizer_async(lora_request))

@overload # DEPRECATED
async def add_request_async(
self,
Expand Down Expand Up @@ -472,6 +478,10 @@ async def add_request_async(
if arrival_time is None:
arrival_time = time.time()

if self.tokenizer is not None:
tokenizer = await self.get_tokenizer_async(lora_request)
self._validate_token_prompt(prompt, tokenizer=tokenizer)

preprocessed_inputs = await self.input_preprocessor.preprocess_async(
prompt,
request_id=request_id,
Expand All @@ -488,7 +498,7 @@ async def add_request_async(
# implementation in the LLMEngine
params = await build_guided_decoding_logits_processor_async(
sampling_params=params,
tokenizer=self.get_tokenizer(lora_request),
tokenizer=await self.get_tokenizer_async(lora_request),
default_guided_backend=self.decoding_config.
guided_decoding_backend)

Expand Down Expand Up @@ -715,8 +725,7 @@ async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return await (self.engine.get_tokenizer_group().
get_lora_tokenizer_async(lora_request))
return await self.engine.get_tokenizer_async(lora_request)

def start_background_loop(self) -> None:
"""Start the background loop."""
Expand Down
40 changes: 36 additions & 4 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Set, Type, Union, cast, overload

import torch
from typing_extensions import TypeVar
from typing_extensions import TypeIs, TypeVar

import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
Expand All @@ -32,7 +32,8 @@
from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
EncoderDecoderInputs, InputRegistry, PromptType)
EncoderDecoderInputs, InputRegistry, PromptType,
TokensPrompt)
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.logits_process import get_bad_words_logits_processors
Expand Down Expand Up @@ -667,7 +668,7 @@ def _add_processed_request(
)
return None

self._validate_model_inputs(processed_inputs)
self._validate_model_inputs(processed_inputs, lora_request)
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
Expand Down Expand Up @@ -829,6 +830,11 @@ def add_request(
if arrival_time is None:
arrival_time = time.time()

if self.tokenizer is not None:
self._validate_token_prompt(
prompt,
tokenizer=self.get_tokenizer(lora_request=lora_request))

preprocessed_inputs = self.input_preprocessor.preprocess(
prompt,
request_id=request_id,
Expand All @@ -855,6 +861,31 @@ def add_request(
priority=priority,
)

def _validate_token_prompt(self, prompt: PromptType,
tokenizer: AnyTokenizer):
# Guard against out-of-vocab tokens.
# For some tokenizers, tokenizer.decode will happily return empty text
# for token ids that are out of vocab, and we don't detect token ids
# that are greater than the max token id before running the model.
# However, these token ids will later crash a cuda kernel at runtime
# with an index out of bounds error. This will crash the entire engine.
# This needs to happen before multimodal input pre-processing, which
# may add dummy <image> tokens that aren't part of the tokenizer's
# vocabulary.
if self._is_token_prompt(prompt):
prompt_ids = prompt["prompt_token_ids"]
if len(prompt_ids) == 0:
# Empty prompt check is handled later
return
max_input_id = max(prompt_ids)
if max_input_id > tokenizer.max_token_id:
raise ValueError(
"Token id {} is out of vocabulary".format(max_input_id))

@staticmethod
def _is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
return isinstance(prompt, dict) and "prompt_token_ids" in prompt

def _create_sequence_group_with_sampling(
self,
request_id: str,
Expand Down Expand Up @@ -1942,7 +1973,8 @@ def is_encoder_decoder_model(self):
return self.input_preprocessor.is_encoder_decoder_model()

def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
EncoderDecoderInputs]):
EncoderDecoderInputs],
lora_request: Optional[LoRARequest]):
if self.model_config.is_multimodal_model:
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
Expand Down
5 changes: 5 additions & 0 deletions vllm/transformers_utils/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
tokenizer.all_special_tokens_extended)
tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
tokenizer_len = len(tokenizer)
max_token_id = max(tokenizer.get_vocab().values())

class CachedTokenizer(tokenizer.__class__): # type: ignore

Expand All @@ -50,6 +51,10 @@ def all_special_tokens(self):
def all_special_tokens_extended(self):
return tokenizer_all_special_tokens_extended

@property
def max_token_id(self):
return max_token_id

def __len__(self):
return tokenizer_len

Expand Down
5 changes: 5 additions & 0 deletions vllm/transformers_utils/tokenizers/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(self, tokenizer: PublicMistralTokenizer) -> None:
raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")

self.tokenizer = tokenizer_
self._max_token_id = max(self._vocab.values())

@classmethod
def from_pretrained(cls,
Expand Down Expand Up @@ -158,6 +159,10 @@ def is_fast(self) -> bool:
def vocab_size(self) -> int:
return len(self._vocab)

@property
def max_token_id(self) -> int:
return self._max_token_id

def __len__(self) -> int:
return self.vocab_size

Expand Down

0 comments on commit 67bdf8e

Please sign in to comment.