Implement prompt/generation alignment #531
base: main
Conversation
I think this is the right general direction.
Could you illustrate this? I had a PR opened (can't find it right now) where I iterated once over the vocabulary to find the overlapping tokens.
Making up a fake example. My prompt is "Good mor". Let's say there's a token for "mor" and it's the last one of the prompt. We would want token alignment to replace "mor" with "morning". However, if the token "ning" by itself does not exist, then there's nothing in the states_to_token_maps that would let us continue the generation afterwards. I was then thinking that a solution could be to create at initialization a mapping that contains both information about characters and about tokens (so we would have some states with no tokens leading to them that would be used for the token alignment).
How about looping over the entire vocabulary and storing the matching tokens? I haven't taken the time to think about the constrained case yet.
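To make the vocabulary-scan idea concrete, here is a minimal sketch; the helper name and toy vocabulary are made up for illustration, not taken from the PR:

def find_crossing_tokens(vocabulary: dict[str, int], prompt: str) -> dict[int, str]:
    """Find tokens that start with a suffix of the prompt and extend past it."""
    crossing = {}
    for token, token_id in vocabulary.items():
        for i in range(len(prompt)):
            suffix = prompt[i:]
            if token.startswith(suffix) and len(token) > len(suffix):
                crossing[token_id] = token
                break
    return crossing

vocab = {"Good": 0, " mor": 1, " morning": 2, " more": 3, "ning": 4}
print(find_crossing_tokens(vocab, "Good mor"))
# {2: ' morning', 3: ' more'} -> " mor" could be swapped for either crossing token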
I had not realized that I could walk the states_to_token_maps.
Yes I think that's the right approach. There's some stuff to figure out in terms of design, but otherwise looks good.
I'll write unit tests next if you think having those separate functions is the right design.
I have made several comments on the overall design, but nothing that would dramatically affect your implementation. You can start implementing tests.
Hi there, I am just a user here who is looking forward to this change. However, I noticed that there is an error if the model is running on a GPU. I think it could be fixed by passing the device in these two statements here (at least, this fixes it for me)
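For illustration, a hedged sketch of the kind of change meant here; the function and variable names are hypothetical, not the PR's actual code:

import torch

def build_allowed_mask(allowed_token_ids: list[int], logits: torch.Tensor) -> torch.Tensor:
    mask = torch.full_like(logits, float("-inf"))
    # creating the index tensor on logits.device avoids the CPU/GPU device mismatch
    allowed = torch.tensor(allowed_token_ids, device=logits.device)
    mask[allowed] = 0.0
    return mask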
We're getting really close. There are a few design changes remaining, and mostly we should have comprehensive tests before merging.
I rebased your branch on main.
That's great news! Please let me know if you run into any issues or have any questions. You probably want to branch from #966 since it has fixes to the logits processors and a more detailed docstring.
To make sure I understand the wider context, the plan is to eventually remove SequenceGenerator?
Indeed
I rebased on your branch and modified my initial commit @lapp0
Could you rebase on main?
This issue is causing problems for the PR. If we don't have an explanation/solution for it, we would have to modify the logic related to the logits processors.
@RobinPicard per my comment in the linked issue, it appears that transformers beam search submits an unused sequence group to logits processors during the final generation step. Is this still an issue if it only occurs on the last step and it's not actually part of generation? Please let me know how I can help.
It's fine if it only happens at the final generation step; I simply added a workaround for it.
Glad that it's not blocking, please let me know if you run into any other issues or have any questions!
I don't have more questions, I would be interested in a review of the PR though!
I think there might be a better way to implement this
Could you give me your thoughts on the following design? If alignment is enabled:
- SequenceGeneratorAdapter: update the prompt token IDs in accordance with alignment, track removedchars
- guide.py: create TokenHealingGuide, which enforces generation starting with removedchars and passes the "unhealed" tokens to the original Guide.
- SequenceGeneratorAdapter: left-truncate len(removedchars) once generation is complete

This would be forward-compatible with CFG, and minimize the context of the token healing to just SequenceGeneratorAdapter and a TokenHealingGuide (sketched below).
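A rough sketch of what such a wrapper could look like; the names and the simplified interface are assumptions for illustration, not the actual outlines API:

class TokenHealingGuide:
    """Forces generation to re-emit removed_chars, then defers to the inner guide."""

    def __init__(self, inner_guide, removed_chars: str, vocabulary: dict[str, int]):
        self.inner = inner_guide
        self.pending = removed_chars
        self.vocabulary = vocabulary

    def allowed_tokens(self, state) -> list[int]:
        if self.pending:
            # only tokens consistent with the characters removed from the prompt
            return [
                token_id
                for token, token_id in self.vocabulary.items()
                if token.startswith(self.pending) or self.pending.startswith(token)
            ]
        return self.inner.allowed_tokens(state)

    def advance(self, state, token: str):
        if self.pending:
            unhealed = token[len(self.pending):] if token.startswith(self.pending) else ""
            self.pending = self.pending[len(token):]  # empty once fully re-emitted
            # only the "unhealed" remainder reaches the original guide
            return self.inner.advance(state, unhealed) if unhealed else state
        return self.inner.advance(state, token)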
outlines/fsm/guide.py
Outdated
"""Update the states_to_token_maps and return the aligned prompt""" | ||
token_ids, _ = tokenizer.encode(prompt) | ||
# possible return types of tokenizers include list, 1d Tensor and 2d Tensor | ||
if not isinstance(token_ids, list): |
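One plausible continuation of that branch, as a sketch (not necessarily the PR's actual handling):

import torch

def to_token_id_list(token_ids):
    # tokenizers may return a list, a 1d Tensor, or a 2d [1, n] Tensor
    if not isinstance(token_ids, list) and isinstance(token_ids, torch.Tensor):
        token_ids = token_ids.reshape(-1).tolist()
    return token_ids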
We should ensure tokenizers have the same behavior rather than handling edge cases here.
Related: #936
I've made a commit for the encode/decode methods of the tokenizer
Great!
Could you please add a test within tests/generate/test_generate.py which verifies correct functionality against all outlines.models?
@pytest.mark.parametrize("pattern", REGEX_PATTERNS)
@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
def test_generate_token_healing_regex(request, model_fixture, pattern):
model = request.getfixturevalue(model_fixture)
# similar to test_generate_regex, but verifies a match of the pattern is generated with token healing as well
outlines/fsm/guide.py
Outdated
@@ -334,6 +418,10 @@ def __init__(self, cfg_string: str, tokenizer):
        self.start_state = 0
        self.final_state = -1

    def align_prompt_tokens(self, prompt: str, tokenizer: "Tokenizer") -> str:
        """Not applicable to this type of Guide"""
Why doesn't alignment work for CFGGuide?
There's no states_to_token_map available before the start of the generation, and there's the problem that the crossing tokens are not valid in the context-free grammar. It's probably possible to implement it for CFGGuide, I just don't know how to do it at the moment.
    def strip_completions(
        self,
        completions,
        prompts: Union[str, List[str]],
SequenceGenerator is only used by exllamav2 at this point, and will no longer be used after #1010. Can we move this logic to SequenceGeneratorAdapter?
This is already a method of SequenceGeneratorAdapter
What I mean is that none of the logic of prompt/generation alignment should exist in SequenceGenerator
I don't get how step 2 would work. Are the unhealed tokens passed down to the original Guide just the user prompt? Another design I'm considering for better separation of token healing and Guides is to create a dedicated class and have Guides that implement token healing inherit from it on top of GuideProtocol. Do you think that would be better?
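For reference, a bare-bones version of that inheritance idea; the names are hypothetical and the real protocol has more methods:

from typing import Protocol

class GuideProtocol(Protocol):
    def get_next_instruction(self, state): ...
    def get_next_state(self, state, token_id): ...

class TokenHealingMixin:
    def align_prompt_tokens(self, prompt: str, tokenizer) -> str:
        # shared token-healing logic would live here; guides that cannot
        # support healing simply don't inherit from the mixin
        return prompt

class RegexGuide(TokenHealingMixin):
    def get_next_instruction(self, state): ...
    def get_next_state(self, state, token_id): ...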
Here's roughly what I'm thinking. In SequenceGeneratorAdapter, … In the Guide, …
This design requires constructing a Guide for each prompt; however, this is necessary because the behavior of the Guide varies prompt-to-prompt. Details surrounding how to pre-process and manage state for each prompt still need to be figured out. Please let me know if this makes sense, or if I'm missing something.
I think I understand the general idea. The main issue I have concerns the unaligned tokens: I don't get what the unaligned token means in this context.
Sorry I wasn't clear here. For example, if the prompt is "hello wo" and we truncate it to "hello" for alignment, and the model generates the token " world" as the next output, "rld" is what needs to be passed to the child guide to continue processing. This allows for alignment to be compatible with any guide with minimal changes.
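In code form, that bookkeeping is just string slicing; a toy helper for illustration:

def unhealed_remainder(removed_chars: str, generated_token: str) -> str:
    """The characters of the healed token that belong to the real generation."""
    assert generated_token.startswith(removed_chars)
    return generated_token[len(removed_chars):]

print(unhealed_remainder(" wo", " world"))  # -> "rld"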
Ah I see, but I thought the problem is that "rld" may not exist as a token in the vocabulary so it would not be found in the states_to_token_maps.
Sorry for the delayed response.
If I'm understanding the problem correctly, to mitigate this we need to determine the "longest common prompt prefix". This will allow any legal token to be generated as a "pseudo-token".
Can we precompute this when the guide is constructed?
[updated 2024-06-28]
The aim of this PR is to implement prompt token alignment.
The idea is to modify the states_to_token_maps of the Guide to include in it the characters of some of the last tokens of the prompt that could be replaced by a different token containing the same characters plus characters for the generation (a crossing token).
To do so, when receiving the user's prompts (so after the OutlinesLogitsProcessor has already been initialized with its FSM), we copy the FSM as many times as there are prompts and apply prompt token alignment to each copy (as the modification of the states_to_token_maps depends on the content of each prompt).
At the end of the process, we modify the generated sequences to remove the characters at the beginning that correspond to the ends of the user prompts.
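As an illustration of that last step, a minimal sketch (helper name made up) that strips the re-generated prompt tail from a completion:

def strip_prompt_overlap(prompt: str, completion: str) -> str:
    # find the longest suffix of the prompt that the completion re-generated
    for i in range(len(prompt)):
        suffix = prompt[i:]
        if completion.startswith(suffix):
            return completion[len(suffix):]
    return completion

print(strip_prompt_overlap("Good mor", "morning, how are you?"))
# -> "ning, how are you?", so the final text reads "Good morning, how are you?"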