Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] Add bad_words_ids sampling parameter #5986

Open
wants to merge 24 commits into
base: main
Choose a base branch
from

Conversation

Alvant
Copy link

@Alvant Alvant commented Jun 29, 2024

FIX #986


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@@ -28,6 +28,77 @@ class SamplingType(IntEnum):
to sample from."""


class NoBadWordsLogitsProcessor:
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if this is the right file for the class. Still I thought it could be placed here near the LogitsProcessor type (the one just above). But if there is a better place for the processor class, I will be ready to move it)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I think that if we include this it should go in a different file.

if len(bad_word_ids) == 1: # 1-token words already processed
continue

if len(bad_word_ids) > len(past_tokens_ids) + 1:
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This differs from the original inequality here: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L1131

if len(sequence_ids) > input_ids.shape[1]:  # the sequence is longer than the context, ignore
    continue

I may be wrong, but it seemed we should allow for one token to be generated (so + 1).

@njhill
Copy link
Member

njhill commented Jul 1, 2024

I'm unsure whether or not it makes sense to support this. I know it's an option in transfomers but it was added very early on and the implementation seems limited/clunky to me. Wouldn't it make more sense for the bad words to be a list of strings rather than token sequences?

@Alvant
Copy link
Author

Alvant commented Jul 2, 2024

I'm unsure whether or not it makes sense to support this. I know it's an option in transfomers but it was added very early on and the implementation seems limited/clunky to me. Wouldn't it make more sense for the bad words to be a list of strings rather than token sequences?

@njhill

Yes, I absolutely agree that this "list of lists of token ids" structure is not very friendly and easy to use 😅 Sure, list of strings will be more convenient. I believe I can change that. Just wanted to clarify some questions before actually making some changes.

If we make bad words as list of strings (and call it, for example, just bad_words), we will lose compatibility with transformers interface. If people come to vLLM after transformers, they might already now about bad_words_ids parameter and how to use it. And they will look for this parameter in SamplingParams attributes.

So, the main and only question is actually the following — should we keep transformers' "clumsy" bad_words_ids option together with more friendly bad_words? should we support list of lists of ids structure or only list of strings?

Oh, and one more point.

Currently, SamplingParams has stop_token_ids parameter (list of token ids). If we make bad_words as list of strings, would it not bring some "heterogeneity"? (one parameter is about token ids, another one is about strings) If bad_words_ids is about token ids, then it is quite consistent with stop_token_ids.

Hmm, just noticed, SamplingParams actually has another parameter — stop, which is like stop_token_ids, but consisting of strings...

Ok, I agree that introducing bad_words instead of bad_words_ids is overall a good idea) At the moment, I do not see a way to change the behavior easily (it seems unlikely that a logits_preprocessor will do the thing as we will not know token ids beforehand), but I will definitely look into this in the coming days) I am just afraid that, after changes, it will be a completely different PR 🙂

@njhill
Copy link
Member

njhill commented Jul 2, 2024

My view on this kind of thing is to collect some concrete requirements / use cases and base on that. I.e. avoid adding things with hypothetical benefit. Would be good to see some explicit examples of how/where this functionality is used, and that should then also inform what kind of thing makes the most sense w.r.t. the various options being discussed.

@Alvant
Copy link
Author

Alvant commented Jul 4, 2024

I looked through some examples and use cases.

First, it seems that bad_words_ids (lists of token sequences) is already a requested feature. Apart from the issue which this PR is linked to, I also found this already closed issue in vLLM (where the proposed solution works only for single tokens). In TensorRT-LLM, they already have this feature, also as lists of token ids. In OpenAI request parameters, there is one called "logit_bias" which is like a generalization of bad_words_ids, and also works on a token level.

However, seems like it would be indeed better to have a list of words and not token sequences. Because people always have to use tokenizers in order to get token ids. This is just stated in the documentation on the transformers docs site. Usually, these tokenizer things lead to questions (not very friendly "low level" functionality).

I also found a few mentions of bad words list (not token ids) functionality in other repositories. For example, in TensorRT-LLM Backend, there is already an option to provide a list of bad words as strings. In outlines, there are plans to add this feature.

So, I updated the PR: now there is bad_words parameter, not bad_words_ids.

However, to some extent, this complicated the implementation. So that I am currently not 100% sure that it is just a "Frontend" related feature 😅

@njhill could you please take a look and share your thoughts on the updated code?

Points which are to be resolved (if the general idea would seem OK, I will fix this):

  • NoBadWordsLogitsProcessor is still in sampling_params.py file
  • AsyncLLMEngine would need corresponding updates

P.S.

For the history, this is the PR which added bad_words_ids to the transformers library: huggingface/transformers#3367.

@Alvant
Copy link
Author

Alvant commented Jul 4, 2024

Forgot to add, in vLLM, there is already something like bad words ids thing: logit_bias_logits_processor, which is added for the compatibility with OpenAI request params.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

It seems that SamplingParams doesnt support the bad_words_ids parameter when generating
2 participants