Skip to content

Commit

Permalink
[Frontend] Bad words sampling parameter (vllm-project#9717)
Browse files Browse the repository at this point in the history
Signed-off-by: Vasily Alexeev <[email protected]>
Signed-off-by: Shanshan Wang <[email protected]>
  • Loading branch information
Alvant authored and cooleel committed Oct 28, 2024
1 parent fb35a80 commit c3449f9
Show file tree
Hide file tree
Showing 6 changed files with 339 additions and 16 deletions.
185 changes: 185 additions & 0 deletions tests/samplers/test_no_bad_words.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
"""Make sure bad_words works.
Run `pytest tests/samplers/test_no_bad_words.py`.
"""
from typing import List, Optional

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams


def _generate(
model: LLM,
prompt: str,
num_prompt_tokens: int,
temperature: float = 0,
bad_words: Optional[List[str]] = None,
) -> List[int]:
sampling_params = SamplingParams(
temperature=temperature,
bad_words=bad_words,
)

# [([output_token_ids, ], [output_text, ]), ]
output = model.generate([prompt], sampling_params=sampling_params)

output_token_ids = output[0][0][0][num_prompt_tokens:]
# [0] first (and only) request output
# [0] token_ids (not text)
# [0] first (and only) output completion

return output_token_ids


class TestOneTokenBadWord:
MODEL = "TheBloke/Llama-2-7B-fp16"

PROMPT = "Hi! How are"
TARGET_TOKEN = "you"

def setup_method(self, method):
self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL,
add_prefix_space=True)

self.num_prompt_tokens = len(self._encode(self.PROMPT))
self.target_token_id = self._encode(self.TARGET_TOKEN,
add_special_tokens=False)[0]

def test_one_token_bad_word(self, vllm_runner):
with vllm_runner(self.MODEL) as llm:
output_token_ids = self._generate(llm)
assert output_token_ids[0] == self.target_token_id

output_token_ids = self._generate(llm,
bad_words=[self.TARGET_TOKEN])
assert self.target_token_id not in output_token_ids

def _generate(self,
model: LLM,
bad_words: Optional[List[str]] = None) -> List[int]:
return _generate(
model=model,
prompt=self.PROMPT,
num_prompt_tokens=self.num_prompt_tokens,
bad_words=bad_words,
)

def _encode(self,
prompt: str,
add_special_tokens: bool = True) -> List[int]:
return self.tokenizer(prompt,
add_special_tokens=add_special_tokens).input_ids


class TestTwoTokenBadWord:
# Another model (with a different tokenizer behaviour)
MODEL = "openai-community/gpt2"

PROMPT = "How old are you? I am 10"
TARGET_TOKEN1 = "years"
TARGET_TOKEN2 = "old"
NEIGHBOUR_TOKEN2 = "older"

def setup_method(self, method):
self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL,
add_prefix_space=True)

self.num_prompt_tokens = len(self._encode(self.PROMPT))
self.target_token_id1 = self._encode(self.TARGET_TOKEN1,
add_special_tokens=False)[0]
self.target_token_id2 = self._encode(self.TARGET_TOKEN2,
add_special_tokens=False)[0]
self.neighbour_token_id2 = self._encode(self.NEIGHBOUR_TOKEN2,
add_special_tokens=False)[0]

def test_two_token_bad_word(self, vllm_runner):
with vllm_runner(self.MODEL) as llm:
output_token_ids = self._generate(llm)
assert output_token_ids[:2] == [
self.target_token_id1, self.target_token_id2
]

output_token_ids = self._generate(llm,
bad_words=[self.TARGET_TOKEN1])
assert self.target_token_id1 not in output_token_ids

output_token_ids = self._generate(llm,
bad_words=[self.TARGET_TOKEN2])
assert output_token_ids[0] == self.target_token_id1
assert self.target_token_id2 not in output_token_ids

output_token_ids = self._generate(
llm, bad_words=[f'{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}'])
assert output_token_ids[0] == self.target_token_id1
assert output_token_ids[:2] != [
self.target_token_id1, self.target_token_id2
]
assert not self._contains(
output_token_ids,
[self.target_token_id1, self.target_token_id2])
# Model dependent behaviour
assert output_token_ids[:2] == [
self.target_token_id1, self.neighbour_token_id2
]

output_token_ids = self._generate(
llm,
bad_words=[
f'{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}',
f'{self.TARGET_TOKEN1} {self.NEIGHBOUR_TOKEN2}'
])
assert output_token_ids[0] == self.target_token_id1
assert output_token_ids[:2] != [
self.target_token_id1, self.target_token_id2
]
assert not self._contains(
output_token_ids,
[self.target_token_id1, self.target_token_id2])
assert output_token_ids[:2] != [
self.target_token_id1, self.neighbour_token_id2
]
assert not self._contains(
output_token_ids,
[self.target_token_id1, self.neighbour_token_id2])
assert ((self.target_token_id2 in output_token_ids)
or (self.neighbour_token_id2 in output_token_ids))

def _generate(self,
model: LLM,
bad_words: Optional[List[str]] = None) -> List[int]:
return _generate(
model=model,
prompt=self.PROMPT,
num_prompt_tokens=self.num_prompt_tokens,
bad_words=bad_words,
)

@staticmethod
def _contains(sequence: List[int], subsequence: List[int]) -> bool:
searched = False

for start in range(len(sequence)):
end = start + len(subsequence)
current_subsequence = sequence[start:end]

if len(current_subsequence) < len(subsequence):
continue

searched = True

assert len(current_subsequence) == len(subsequence)

if current_subsequence == subsequence:
return True

assert searched, "All subsequences did not match in length..."

return False

def _encode(self,
prompt: str,
add_special_tokens: bool = True) -> List[int]:
return self.tokenizer(prompt,
add_special_tokens=add_special_tokens).input_ids
13 changes: 11 additions & 2 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,16 @@
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.entrypoints.openai.logits_processors import (
get_logits_processors as get_openai_logits_processors)
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
EncoderDecoderInputs, InputRegistry, PromptType)
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.logits_process import get_bad_words_logits_processors
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_local_guided_decoding_logits_processor)
Expand Down Expand Up @@ -1963,6 +1965,7 @@ def _build_logits_processors(
logits_processors field. Returns the modified sampling params."""

logits_processors = []

if (guided_decoding := sampling_params.guided_decoding) is not None:

logger.debug(
Expand All @@ -1984,7 +1987,7 @@ def _build_logits_processors(
if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
tokenizer = self.get_tokenizer(lora_request=lora_request)

processors = get_logits_processors(
processors = get_openai_logits_processors(
logit_bias=sampling_params.logit_bias,
allowed_token_ids=sampling_params.allowed_token_ids,
tokenizer=tokenizer)
Expand All @@ -1994,6 +1997,12 @@ def _build_logits_processors(
sampling_params.logit_bias = None
sampling_params.allowed_token_ids = None

if len(sampling_params.bad_words) > 0:
tokenizer = self.get_tokenizer(lora_request)
processors = get_bad_words_logits_processors(
bad_words=sampling_params.bad_words, tokenizer=tokenizer)
logits_processors.extend(processors)

if logits_processors:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = logits_processors
Expand Down
119 changes: 119 additions & 0 deletions vllm/logits_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from typing import Callable, List, Tuple, Union

import torch

from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer

LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor],
Callable[[List[int], List[int], torch.Tensor],
torch.Tensor]]
"""LogitsProcessor is a function that takes a list
of previously generated tokens, the logits tensor
for the next token and, optionally, prompt tokens as a
first argument, and returns a modified tensor of logits
to sample from."""


def get_bad_words_logits_processors(
bad_words: List[str],
tokenizer: AnyTokenizer) -> List[LogitsProcessor]:
bad_words_ids: List[List[int]] = list()

for bad_word in bad_words:
# To prohibit words both at the beginning
# and in the middle of text
# (related to add_prefix_space tokenizer parameter)
for add_prefix_space in [False, True]:
prefix = " " if add_prefix_space else ""
prompt = prefix + bad_word.lstrip()

if isinstance(tokenizer, MistralTokenizer):
# Mistral tokenizers should not add special tokens
prompt_token_ids = tokenizer.encode(prompt=prompt)
else:
prompt_token_ids = tokenizer.encode(text=prompt,
add_special_tokens=False)

# If no space at the beginning
# or if prefix space produces a new word token
if (not add_prefix_space) or (
add_prefix_space
and prompt_token_ids[0] != bad_words_ids[-1][0]
and len(prompt_token_ids) == len(bad_words_ids[-1])):
bad_words_ids.append(prompt_token_ids)

return [NoBadWordsLogitsProcessor(bad_words_ids=bad_words_ids)]


class NoBadWordsLogitsProcessor:
_SMALLEST_LOGIT = float("-inf")
_NEUTRAL_LOGIT = 0.0

def __init__(self, bad_words_ids: List[List[int]]):
self.bad_words_ids = bad_words_ids
self.word_bias: torch.FloatTensor = None

def __call__(
self,
past_tokens_ids: Union[List[int], Tuple[int]],
logits: torch.FloatTensor,
) -> torch.Tensor:
if self.word_bias is None:
self._init_word_bias(logits=logits)

last_token_bias = torch.zeros_like(logits)

for bad_word_ids in self.bad_words_ids:
if len(bad_word_ids) == 1: # 1-token words already processed
continue

if len(bad_word_ids) > len(past_tokens_ids) + 1:
continue

prefix_length = len(bad_word_ids) - 1
last_token_id = bad_word_ids[-1]
actual_prefix = past_tokens_ids[-prefix_length:]
expected_prefix = bad_word_ids[:prefix_length]

assert len(actual_prefix) == len(expected_prefix)

is_match = tuple(actual_prefix) == tuple(expected_prefix)
last_token_bias[last_token_id] += (self._SMALLEST_LOGIT if is_match
else self._NEUTRAL_LOGIT)

logits = logits + self.word_bias + last_token_bias

return logits

def _init_word_bias(self, logits: torch.FloatTensor) -> None:
# Code based on NoBadWordsLogitsProcessor and SequenceBiasLogitsProcessor # noqa: E501
# from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py

vocab_size = logits.shape[-1]

self._check_token_ids_bounds(vocab_size=vocab_size)

self.word_bias = torch.zeros((vocab_size, ),
dtype=torch.float,
device=logits.device)

for bad_word_ids in self.bad_words_ids:
if len(bad_word_ids) == 1:
bad_word_id = bad_word_ids[-1]
self.word_bias[bad_word_id] = self._SMALLEST_LOGIT

def _check_token_ids_bounds(self, vocab_size: int) -> None:
invalid_token_ids = []

for bad_word_ids in self.bad_words_ids:
for token_id in bad_word_ids:
if token_id < 0 or token_id >= vocab_size:
invalid_token_ids.append(token_id)

if len(invalid_token_ids) > 0:
raise ValueError(
f"The model vocabulary size is {vocab_size},"
f" but the following tokens"
f" were specified as bad: {invalid_token_ids}."
f" All token id values should be integers satisfying:"
f" 0 <= token_id < {vocab_size}.")
3 changes: 2 additions & 1 deletion vllm/model_executor/guided_decoding/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional

from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor
from vllm.logits_process import LogitsProcessor
from vllm.sampling_params import GuidedDecodingParams


async def get_guided_decoding_logits_processor(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data)
from transformers import PreTrainedTokenizerBase

from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor
from vllm.logits_process import LogitsProcessor
from vllm.sampling_params import GuidedDecodingParams


def get_local_lm_format_enforcer_guided_decoding_logits_processor(
Expand Down
Loading

0 comments on commit c3449f9

Please sign in to comment.