diff --git a/tests/samplers/test_no_bad_words.py b/tests/samplers/test_no_bad_words.py new file mode 100644 index 0000000000000..4190cf7cd7664 --- /dev/null +++ b/tests/samplers/test_no_bad_words.py @@ -0,0 +1,185 @@ +"""Make sure bad_words works. + +Run `pytest tests/samplers/test_no_bad_words.py`. + +""" +from typing import List, Optional + +from transformers import AutoTokenizer + +from vllm import LLM, SamplingParams + + +def _generate( + model: LLM, + prompt: str, + num_prompt_tokens: int, + temperature: float = 0, + bad_words: Optional[List[str]] = None, +) -> List[int]: + sampling_params = SamplingParams( + temperature=temperature, + bad_words=bad_words, + ) + + # [([output_token_ids, ], [output_text, ]), ] + output = model.generate([prompt], sampling_params=sampling_params) + + output_token_ids = output[0][0][0][num_prompt_tokens:] + # [0] first (and only) request output + # [0] token_ids (not text) + # [0] first (and only) output completion + + return output_token_ids + + +class TestOneTokenBadWord: + MODEL = "TheBloke/Llama-2-7B-fp16" + + PROMPT = "Hi! How are" + TARGET_TOKEN = "you" + + def setup_method(self, method): + self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL, + add_prefix_space=True) + + self.num_prompt_tokens = len(self._encode(self.PROMPT)) + self.target_token_id = self._encode(self.TARGET_TOKEN, + add_special_tokens=False)[0] + + def test_one_token_bad_word(self, vllm_runner): + with vllm_runner(self.MODEL) as llm: + output_token_ids = self._generate(llm) + assert output_token_ids[0] == self.target_token_id + + output_token_ids = self._generate(llm, + bad_words=[self.TARGET_TOKEN]) + assert self.target_token_id not in output_token_ids + + def _generate(self, + model: LLM, + bad_words: Optional[List[str]] = None) -> List[int]: + return _generate( + model=model, + prompt=self.PROMPT, + num_prompt_tokens=self.num_prompt_tokens, + bad_words=bad_words, + ) + + def _encode(self, + prompt: str, + add_special_tokens: bool = True) -> List[int]: + return self.tokenizer(prompt, + add_special_tokens=add_special_tokens).input_ids + + +class TestTwoTokenBadWord: + # Another model (with a different tokenizer behaviour) + MODEL = "openai-community/gpt2" + + PROMPT = "How old are you? I am 10" + TARGET_TOKEN1 = "years" + TARGET_TOKEN2 = "old" + NEIGHBOUR_TOKEN2 = "older" + + def setup_method(self, method): + self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL, + add_prefix_space=True) + + self.num_prompt_tokens = len(self._encode(self.PROMPT)) + self.target_token_id1 = self._encode(self.TARGET_TOKEN1, + add_special_tokens=False)[0] + self.target_token_id2 = self._encode(self.TARGET_TOKEN2, + add_special_tokens=False)[0] + self.neighbour_token_id2 = self._encode(self.NEIGHBOUR_TOKEN2, + add_special_tokens=False)[0] + + def test_two_token_bad_word(self, vllm_runner): + with vllm_runner(self.MODEL) as llm: + output_token_ids = self._generate(llm) + assert output_token_ids[:2] == [ + self.target_token_id1, self.target_token_id2 + ] + + output_token_ids = self._generate(llm, + bad_words=[self.TARGET_TOKEN1]) + assert self.target_token_id1 not in output_token_ids + + output_token_ids = self._generate(llm, + bad_words=[self.TARGET_TOKEN2]) + assert output_token_ids[0] == self.target_token_id1 + assert self.target_token_id2 not in output_token_ids + + output_token_ids = self._generate( + llm, bad_words=[f'{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}']) + assert output_token_ids[0] == self.target_token_id1 + assert output_token_ids[:2] != [ + self.target_token_id1, self.target_token_id2 + ] + assert not self._contains( + output_token_ids, + [self.target_token_id1, self.target_token_id2]) + # Model dependent behaviour + assert output_token_ids[:2] == [ + self.target_token_id1, self.neighbour_token_id2 + ] + + output_token_ids = self._generate( + llm, + bad_words=[ + f'{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}', + f'{self.TARGET_TOKEN1} {self.NEIGHBOUR_TOKEN2}' + ]) + assert output_token_ids[0] == self.target_token_id1 + assert output_token_ids[:2] != [ + self.target_token_id1, self.target_token_id2 + ] + assert not self._contains( + output_token_ids, + [self.target_token_id1, self.target_token_id2]) + assert output_token_ids[:2] != [ + self.target_token_id1, self.neighbour_token_id2 + ] + assert not self._contains( + output_token_ids, + [self.target_token_id1, self.neighbour_token_id2]) + assert ((self.target_token_id2 in output_token_ids) + or (self.neighbour_token_id2 in output_token_ids)) + + def _generate(self, + model: LLM, + bad_words: Optional[List[str]] = None) -> List[int]: + return _generate( + model=model, + prompt=self.PROMPT, + num_prompt_tokens=self.num_prompt_tokens, + bad_words=bad_words, + ) + + @staticmethod + def _contains(sequence: List[int], subsequence: List[int]) -> bool: + searched = False + + for start in range(len(sequence)): + end = start + len(subsequence) + current_subsequence = sequence[start:end] + + if len(current_subsequence) < len(subsequence): + continue + + searched = True + + assert len(current_subsequence) == len(subsequence) + + if current_subsequence == subsequence: + return True + + assert searched, "All subsequences did not match in length..." + + return False + + def _encode(self, + prompt: str, + add_special_tokens: bool = True) -> List[int]: + return self.tokenizer(prompt, + add_special_tokens=add_special_tokens).input_ids diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1dd0f097c74ff..ede77f04b1db9 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -26,7 +26,8 @@ SequenceGroupOutputProcessor) from vllm.engine.output_processor.stop_checker import StopChecker from vllm.engine.output_processor.util import create_output_by_sequence_group -from vllm.entrypoints.openai.logits_processors import get_logits_processors +from vllm.entrypoints.openai.logits_processors import ( + get_logits_processors as get_openai_logits_processors) from vllm.executor.executor_base import ExecutorBase from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.ray_utils import initialize_ray_cluster @@ -34,6 +35,7 @@ EncoderDecoderInputs, InputRegistry, PromptType) from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger +from vllm.logits_process import get_bad_words_logits_processors from vllm.lora.request import LoRARequest from vllm.model_executor.guided_decoding import ( get_local_guided_decoding_logits_processor) @@ -1963,6 +1965,7 @@ def _build_logits_processors( logits_processors field. Returns the modified sampling params.""" logits_processors = [] + if (guided_decoding := sampling_params.guided_decoding) is not None: logger.debug( @@ -1984,7 +1987,7 @@ def _build_logits_processors( if (sampling_params.logit_bias or sampling_params.allowed_token_ids): tokenizer = self.get_tokenizer(lora_request=lora_request) - processors = get_logits_processors( + processors = get_openai_logits_processors( logit_bias=sampling_params.logit_bias, allowed_token_ids=sampling_params.allowed_token_ids, tokenizer=tokenizer) @@ -1994,6 +1997,12 @@ def _build_logits_processors( sampling_params.logit_bias = None sampling_params.allowed_token_ids = None + if len(sampling_params.bad_words) > 0: + tokenizer = self.get_tokenizer(lora_request) + processors = get_bad_words_logits_processors( + bad_words=sampling_params.bad_words, tokenizer=tokenizer) + logits_processors.extend(processors) + if logits_processors: if sampling_params.logits_processors is None: sampling_params.logits_processors = logits_processors diff --git a/vllm/logits_process.py b/vllm/logits_process.py new file mode 100644 index 0000000000000..7716ccd27e253 --- /dev/null +++ b/vllm/logits_process.py @@ -0,0 +1,119 @@ +from typing import Callable, List, Tuple, Union + +import torch + +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer + +LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor], + Callable[[List[int], List[int], torch.Tensor], + torch.Tensor]] +"""LogitsProcessor is a function that takes a list +of previously generated tokens, the logits tensor +for the next token and, optionally, prompt tokens as a +first argument, and returns a modified tensor of logits +to sample from.""" + + +def get_bad_words_logits_processors( + bad_words: List[str], + tokenizer: AnyTokenizer) -> List[LogitsProcessor]: + bad_words_ids: List[List[int]] = list() + + for bad_word in bad_words: + # To prohibit words both at the beginning + # and in the middle of text + # (related to add_prefix_space tokenizer parameter) + for add_prefix_space in [False, True]: + prefix = " " if add_prefix_space else "" + prompt = prefix + bad_word.lstrip() + + if isinstance(tokenizer, MistralTokenizer): + # Mistral tokenizers should not add special tokens + prompt_token_ids = tokenizer.encode(prompt=prompt) + else: + prompt_token_ids = tokenizer.encode(text=prompt, + add_special_tokens=False) + + # If no space at the beginning + # or if prefix space produces a new word token + if (not add_prefix_space) or ( + add_prefix_space + and prompt_token_ids[0] != bad_words_ids[-1][0] + and len(prompt_token_ids) == len(bad_words_ids[-1])): + bad_words_ids.append(prompt_token_ids) + + return [NoBadWordsLogitsProcessor(bad_words_ids=bad_words_ids)] + + +class NoBadWordsLogitsProcessor: + _SMALLEST_LOGIT = float("-inf") + _NEUTRAL_LOGIT = 0.0 + + def __init__(self, bad_words_ids: List[List[int]]): + self.bad_words_ids = bad_words_ids + self.word_bias: torch.FloatTensor = None + + def __call__( + self, + past_tokens_ids: Union[List[int], Tuple[int]], + logits: torch.FloatTensor, + ) -> torch.Tensor: + if self.word_bias is None: + self._init_word_bias(logits=logits) + + last_token_bias = torch.zeros_like(logits) + + for bad_word_ids in self.bad_words_ids: + if len(bad_word_ids) == 1: # 1-token words already processed + continue + + if len(bad_word_ids) > len(past_tokens_ids) + 1: + continue + + prefix_length = len(bad_word_ids) - 1 + last_token_id = bad_word_ids[-1] + actual_prefix = past_tokens_ids[-prefix_length:] + expected_prefix = bad_word_ids[:prefix_length] + + assert len(actual_prefix) == len(expected_prefix) + + is_match = tuple(actual_prefix) == tuple(expected_prefix) + last_token_bias[last_token_id] += (self._SMALLEST_LOGIT if is_match + else self._NEUTRAL_LOGIT) + + logits = logits + self.word_bias + last_token_bias + + return logits + + def _init_word_bias(self, logits: torch.FloatTensor) -> None: + # Code based on NoBadWordsLogitsProcessor and SequenceBiasLogitsProcessor # noqa: E501 + # from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py + + vocab_size = logits.shape[-1] + + self._check_token_ids_bounds(vocab_size=vocab_size) + + self.word_bias = torch.zeros((vocab_size, ), + dtype=torch.float, + device=logits.device) + + for bad_word_ids in self.bad_words_ids: + if len(bad_word_ids) == 1: + bad_word_id = bad_word_ids[-1] + self.word_bias[bad_word_id] = self._SMALLEST_LOGIT + + def _check_token_ids_bounds(self, vocab_size: int) -> None: + invalid_token_ids = [] + + for bad_word_ids in self.bad_words_ids: + for token_id in bad_word_ids: + if token_id < 0 or token_id >= vocab_size: + invalid_token_ids.append(token_id) + + if len(invalid_token_ids) > 0: + raise ValueError( + f"The model vocabulary size is {vocab_size}," + f" but the following tokens" + f" were specified as bad: {invalid_token_ids}." + f" All token id values should be integers satisfying:" + f" 0 <= token_id < {vocab_size}.") diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 368436aa14613..d7b67425fcbc0 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -1,6 +1,7 @@ from typing import Optional -from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor +from vllm.logits_process import LogitsProcessor +from vllm.sampling_params import GuidedDecodingParams async def get_guided_decoding_logits_processor( diff --git a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py index cf2162ed7720d..a17e75a80300f 100644 --- a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +++ b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py @@ -9,7 +9,8 @@ build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data) from transformers import PreTrainedTokenizerBase -from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor +from vllm.logits_process import LogitsProcessor +from vllm.sampling_params import GuidedDecodingParams def get_local_lm_format_enforcer_guided_decoding_logits_processor( diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 9993cec13d649..bac32c991a0e3 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -3,14 +3,14 @@ from dataclasses import dataclass from enum import Enum, IntEnum from functools import cached_property -from typing import Any, Callable, Dict, List, Optional, Set, Union +from typing import Any, Dict, List, Optional, Set, Union import msgspec -import torch from pydantic import BaseModel from typing_extensions import Annotated from vllm.logger import init_logger +from vllm.logits_process import LogitsProcessor logger = init_logger(__name__) @@ -24,16 +24,6 @@ class SamplingType(IntEnum): RANDOM_SEED = 2 -LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor], - Callable[[List[int], List[int], torch.Tensor], - torch.Tensor]] -"""LogitsProcessor is a function that takes a list -of previously generated tokens, the logits tensor -for the next token and, optionally, prompt tokens as a -first argument, and returns a modified tensor of logits -to sample from.""" - - # maybe make msgspec? @dataclass class GuidedDecodingParams: @@ -139,6 +129,10 @@ class SamplingParams( stop_token_ids: List of tokens that stop the generation when they are generated. The returned output will contain the stop tokens unless the stop tokens are special tokens. + bad_words: List of words that are not allowed to be generated. + More precisely, only the last token of a corresponding + token sequence is not allowed when the next generated token + can complete the sequence. include_stop_str_in_output: Whether to include the stop strings in output text. Defaults to False. ignore_eos: Whether to ignore the EOS token and continue generating @@ -186,6 +180,7 @@ class SamplingParams( seed: Optional[int] = None stop: Optional[Union[str, List[str]]] = None stop_token_ids: Optional[List[int]] = None + bad_words: Optional[List[str]] = None ignore_eos: bool = False max_tokens: Optional[int] = 16 min_tokens: int = 0 @@ -228,6 +223,7 @@ def from_optional( seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stop_token_ids: Optional[List[int]] = None, + bad_words: Optional[List[str]] = None, include_stop_str_in_output: bool = False, ignore_eos: bool = False, max_tokens: Optional[int] = 16, @@ -267,6 +263,7 @@ def from_optional( seed=seed, stop=stop, stop_token_ids=stop_token_ids, + bad_words=bad_words, include_stop_str_in_output=include_stop_str_in_output, ignore_eos=ignore_eos, max_tokens=max_tokens, @@ -298,26 +295,36 @@ def __post_init__(self) -> None: f"got n={self.n} and best_of={self.best_of}.") self._real_n = self.n self.n = self.best_of + if 0 < self.temperature < _MAX_TEMP: logger.warning( "temperature %s is less than %s, which may cause numerical " "errors nan or inf in tensors. We have maxed it out to %s.", self.temperature, _MAX_TEMP, _MAX_TEMP) self.temperature = max(self.temperature, _MAX_TEMP) + if self.seed == -1: self.seed = None else: self.seed = self.seed + if self.stop is None: self.stop = [] elif isinstance(self.stop, str): self.stop = [self.stop] else: self.stop = list(self.stop) + if self.stop_token_ids is None: self.stop_token_ids = [] else: self.stop_token_ids = list(self.stop_token_ids) + + if self.bad_words is None: + self.bad_words = [] + else: + self.bad_words = list(self.bad_words) + self.logprobs = 1 if self.logprobs is True else self.logprobs self.prompt_logprobs = (1 if self.prompt_logprobs is True else self.prompt_logprobs) @@ -468,6 +475,7 @@ def __repr__(self) -> str: f"seed={self.seed}, " f"stop={self.stop}, " f"stop_token_ids={self.stop_token_ids}, " + f"bad_words={self.bad_words}, " f"include_stop_str_in_output={self.include_stop_str_in_output}, " f"ignore_eos={self.ignore_eos}, " f"max_tokens={self.max_tokens}, "