forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Frontend] Bad words sampling parameter (vllm-project#9717)
Signed-off-by: Vasily Alexeev <[email protected]> Signed-off-by: qishuai <[email protected]>
- Loading branch information
1 parent
5a4894e
commit 9360ad3
Showing
6 changed files
with
339 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,185 @@ | ||
"""Make sure bad_words works. | ||
Run `pytest tests/samplers/test_no_bad_words.py`. | ||
""" | ||
from typing import List, Optional | ||
|
||
from transformers import AutoTokenizer | ||
|
||
from vllm import LLM, SamplingParams | ||
|
||
|
||
def _generate( | ||
model: LLM, | ||
prompt: str, | ||
num_prompt_tokens: int, | ||
temperature: float = 0, | ||
bad_words: Optional[List[str]] = None, | ||
) -> List[int]: | ||
sampling_params = SamplingParams( | ||
temperature=temperature, | ||
bad_words=bad_words, | ||
) | ||
|
||
# [([output_token_ids, ], [output_text, ]), ] | ||
output = model.generate([prompt], sampling_params=sampling_params) | ||
|
||
output_token_ids = output[0][0][0][num_prompt_tokens:] | ||
# [0] first (and only) request output | ||
# [0] token_ids (not text) | ||
# [0] first (and only) output completion | ||
|
||
return output_token_ids | ||
|
||
|
||
class TestOneTokenBadWord: | ||
MODEL = "TheBloke/Llama-2-7B-fp16" | ||
|
||
PROMPT = "Hi! How are" | ||
TARGET_TOKEN = "you" | ||
|
||
def setup_method(self, method): | ||
self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL, | ||
add_prefix_space=True) | ||
|
||
self.num_prompt_tokens = len(self._encode(self.PROMPT)) | ||
self.target_token_id = self._encode(self.TARGET_TOKEN, | ||
add_special_tokens=False)[0] | ||
|
||
def test_one_token_bad_word(self, vllm_runner): | ||
with vllm_runner(self.MODEL) as llm: | ||
output_token_ids = self._generate(llm) | ||
assert output_token_ids[0] == self.target_token_id | ||
|
||
output_token_ids = self._generate(llm, | ||
bad_words=[self.TARGET_TOKEN]) | ||
assert self.target_token_id not in output_token_ids | ||
|
||
def _generate(self, | ||
model: LLM, | ||
bad_words: Optional[List[str]] = None) -> List[int]: | ||
return _generate( | ||
model=model, | ||
prompt=self.PROMPT, | ||
num_prompt_tokens=self.num_prompt_tokens, | ||
bad_words=bad_words, | ||
) | ||
|
||
def _encode(self, | ||
prompt: str, | ||
add_special_tokens: bool = True) -> List[int]: | ||
return self.tokenizer(prompt, | ||
add_special_tokens=add_special_tokens).input_ids | ||
|
||
|
||
class TestTwoTokenBadWord: | ||
# Another model (with a different tokenizer behaviour) | ||
MODEL = "openai-community/gpt2" | ||
|
||
PROMPT = "How old are you? I am 10" | ||
TARGET_TOKEN1 = "years" | ||
TARGET_TOKEN2 = "old" | ||
NEIGHBOUR_TOKEN2 = "older" | ||
|
||
def setup_method(self, method): | ||
self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL, | ||
add_prefix_space=True) | ||
|
||
self.num_prompt_tokens = len(self._encode(self.PROMPT)) | ||
self.target_token_id1 = self._encode(self.TARGET_TOKEN1, | ||
add_special_tokens=False)[0] | ||
self.target_token_id2 = self._encode(self.TARGET_TOKEN2, | ||
add_special_tokens=False)[0] | ||
self.neighbour_token_id2 = self._encode(self.NEIGHBOUR_TOKEN2, | ||
add_special_tokens=False)[0] | ||
|
||
def test_two_token_bad_word(self, vllm_runner): | ||
with vllm_runner(self.MODEL) as llm: | ||
output_token_ids = self._generate(llm) | ||
assert output_token_ids[:2] == [ | ||
self.target_token_id1, self.target_token_id2 | ||
] | ||
|
||
output_token_ids = self._generate(llm, | ||
bad_words=[self.TARGET_TOKEN1]) | ||
assert self.target_token_id1 not in output_token_ids | ||
|
||
output_token_ids = self._generate(llm, | ||
bad_words=[self.TARGET_TOKEN2]) | ||
assert output_token_ids[0] == self.target_token_id1 | ||
assert self.target_token_id2 not in output_token_ids | ||
|
||
output_token_ids = self._generate( | ||
llm, bad_words=[f'{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}']) | ||
assert output_token_ids[0] == self.target_token_id1 | ||
assert output_token_ids[:2] != [ | ||
self.target_token_id1, self.target_token_id2 | ||
] | ||
assert not self._contains( | ||
output_token_ids, | ||
[self.target_token_id1, self.target_token_id2]) | ||
# Model dependent behaviour | ||
assert output_token_ids[:2] == [ | ||
self.target_token_id1, self.neighbour_token_id2 | ||
] | ||
|
||
output_token_ids = self._generate( | ||
llm, | ||
bad_words=[ | ||
f'{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}', | ||
f'{self.TARGET_TOKEN1} {self.NEIGHBOUR_TOKEN2}' | ||
]) | ||
assert output_token_ids[0] == self.target_token_id1 | ||
assert output_token_ids[:2] != [ | ||
self.target_token_id1, self.target_token_id2 | ||
] | ||
assert not self._contains( | ||
output_token_ids, | ||
[self.target_token_id1, self.target_token_id2]) | ||
assert output_token_ids[:2] != [ | ||
self.target_token_id1, self.neighbour_token_id2 | ||
] | ||
assert not self._contains( | ||
output_token_ids, | ||
[self.target_token_id1, self.neighbour_token_id2]) | ||
assert ((self.target_token_id2 in output_token_ids) | ||
or (self.neighbour_token_id2 in output_token_ids)) | ||
|
||
def _generate(self, | ||
model: LLM, | ||
bad_words: Optional[List[str]] = None) -> List[int]: | ||
return _generate( | ||
model=model, | ||
prompt=self.PROMPT, | ||
num_prompt_tokens=self.num_prompt_tokens, | ||
bad_words=bad_words, | ||
) | ||
|
||
@staticmethod | ||
def _contains(sequence: List[int], subsequence: List[int]) -> bool: | ||
searched = False | ||
|
||
for start in range(len(sequence)): | ||
end = start + len(subsequence) | ||
current_subsequence = sequence[start:end] | ||
|
||
if len(current_subsequence) < len(subsequence): | ||
continue | ||
|
||
searched = True | ||
|
||
assert len(current_subsequence) == len(subsequence) | ||
|
||
if current_subsequence == subsequence: | ||
return True | ||
|
||
assert searched, "All subsequences did not match in length..." | ||
|
||
return False | ||
|
||
def _encode(self, | ||
prompt: str, | ||
add_special_tokens: bool = True) -> List[int]: | ||
return self.tokenizer(prompt, | ||
add_special_tokens=add_special_tokens).input_ids |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
from typing import Callable, List, Tuple, Union | ||
|
||
import torch | ||
|
||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer | ||
|
||
LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor], | ||
Callable[[List[int], List[int], torch.Tensor], | ||
torch.Tensor]] | ||
"""LogitsProcessor is a function that takes a list | ||
of previously generated tokens, the logits tensor | ||
for the next token and, optionally, prompt tokens as a | ||
first argument, and returns a modified tensor of logits | ||
to sample from.""" | ||
|
||
|
||
def get_bad_words_logits_processors( | ||
bad_words: List[str], | ||
tokenizer: AnyTokenizer) -> List[LogitsProcessor]: | ||
bad_words_ids: List[List[int]] = list() | ||
|
||
for bad_word in bad_words: | ||
# To prohibit words both at the beginning | ||
# and in the middle of text | ||
# (related to add_prefix_space tokenizer parameter) | ||
for add_prefix_space in [False, True]: | ||
prefix = " " if add_prefix_space else "" | ||
prompt = prefix + bad_word.lstrip() | ||
|
||
if isinstance(tokenizer, MistralTokenizer): | ||
# Mistral tokenizers should not add special tokens | ||
prompt_token_ids = tokenizer.encode(prompt=prompt) | ||
else: | ||
prompt_token_ids = tokenizer.encode(text=prompt, | ||
add_special_tokens=False) | ||
|
||
# If no space at the beginning | ||
# or if prefix space produces a new word token | ||
if (not add_prefix_space) or ( | ||
add_prefix_space | ||
and prompt_token_ids[0] != bad_words_ids[-1][0] | ||
and len(prompt_token_ids) == len(bad_words_ids[-1])): | ||
bad_words_ids.append(prompt_token_ids) | ||
|
||
return [NoBadWordsLogitsProcessor(bad_words_ids=bad_words_ids)] | ||
|
||
|
||
class NoBadWordsLogitsProcessor: | ||
_SMALLEST_LOGIT = float("-inf") | ||
_NEUTRAL_LOGIT = 0.0 | ||
|
||
def __init__(self, bad_words_ids: List[List[int]]): | ||
self.bad_words_ids = bad_words_ids | ||
self.word_bias: torch.FloatTensor = None | ||
|
||
def __call__( | ||
self, | ||
past_tokens_ids: Union[List[int], Tuple[int]], | ||
logits: torch.FloatTensor, | ||
) -> torch.Tensor: | ||
if self.word_bias is None: | ||
self._init_word_bias(logits=logits) | ||
|
||
last_token_bias = torch.zeros_like(logits) | ||
|
||
for bad_word_ids in self.bad_words_ids: | ||
if len(bad_word_ids) == 1: # 1-token words already processed | ||
continue | ||
|
||
if len(bad_word_ids) > len(past_tokens_ids) + 1: | ||
continue | ||
|
||
prefix_length = len(bad_word_ids) - 1 | ||
last_token_id = bad_word_ids[-1] | ||
actual_prefix = past_tokens_ids[-prefix_length:] | ||
expected_prefix = bad_word_ids[:prefix_length] | ||
|
||
assert len(actual_prefix) == len(expected_prefix) | ||
|
||
is_match = tuple(actual_prefix) == tuple(expected_prefix) | ||
last_token_bias[last_token_id] += (self._SMALLEST_LOGIT if is_match | ||
else self._NEUTRAL_LOGIT) | ||
|
||
logits = logits + self.word_bias + last_token_bias | ||
|
||
return logits | ||
|
||
def _init_word_bias(self, logits: torch.FloatTensor) -> None: | ||
# Code based on NoBadWordsLogitsProcessor and SequenceBiasLogitsProcessor # noqa: E501 | ||
# from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py | ||
|
||
vocab_size = logits.shape[-1] | ||
|
||
self._check_token_ids_bounds(vocab_size=vocab_size) | ||
|
||
self.word_bias = torch.zeros((vocab_size, ), | ||
dtype=torch.float, | ||
device=logits.device) | ||
|
||
for bad_word_ids in self.bad_words_ids: | ||
if len(bad_word_ids) == 1: | ||
bad_word_id = bad_word_ids[-1] | ||
self.word_bias[bad_word_id] = self._SMALLEST_LOGIT | ||
|
||
def _check_token_ids_bounds(self, vocab_size: int) -> None: | ||
invalid_token_ids = [] | ||
|
||
for bad_word_ids in self.bad_words_ids: | ||
for token_id in bad_word_ids: | ||
if token_id < 0 or token_id >= vocab_size: | ||
invalid_token_ids.append(token_id) | ||
|
||
if len(invalid_token_ids) > 0: | ||
raise ValueError( | ||
f"The model vocabulary size is {vocab_size}," | ||
f" but the following tokens" | ||
f" were specified as bad: {invalid_token_ids}." | ||
f" All token id values should be integers satisfying:" | ||
f" 0 <= token_id < {vocab_size}.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.