Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement prompt/generation alignment #531

Open
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

RobinPicard
Copy link
Contributor

@RobinPicard RobinPicard commented Jan 11, 2024

[updated 2024-06-28]

The aim of this PR to implement prompt token alignment

The idea is to modify the states_to_token_maps of the Guide to include in it the characters of some of the last tokens of the prompt that could be replaced by a different token that contains the same set of characters plus characters for the generation (a crossing token).

To do so, when receiving the prompts of the user (so after the OutlinesLogitsProcessor has already been initialized with its FSM), we copy the FSM as many as their are prompts and we apply to each of them prompt token alignment (as the modification of the states_to_token_maps depends on the content of each prompt).

At the end of the process, we modify the generated sequences to remove the characters at the beginning that correspond to the ends of the user prompts

@rlouf rlouf added enhancement transformers Linked to the `transformers` integration correctness Everything related to the generation correctness labels Jan 16, 2024
@rlouf
Copy link
Member

rlouf commented Jan 27, 2024

This is not intended to be merged, I was just wondering whether you think this is a promising direction to look into

I think this is the right general direction.

  • The case in which the text after the "boundary" of a token matching the end of the prompt does not exist in the vocabulary by itself is not covered

Could you illustrate this? I had a PR opened (can't find it right now) where I iterated once over the vocabulary to find the overlapping tokens.

@rlouf rlouf linked an issue Jan 27, 2024 that may be closed by this pull request
@RobinPicard
Copy link
Contributor Author

Could you illustrate this? I had a PR opened (can't find it right now) where I iterated once over the vocabulary to find the overlapping tokens.

Making up a fake example. My prompt is "Good mor". Let's say there's a token for "mor" and it's the last one of the prompt. We would want token alignment to replace "mor" with "morning". However, if the token "ning" by itself does not exist, then there's nothing in the states_to_token_maps that correspond to it as, at this point, the character-based FSM that would allow to generate "ning" has already been turned into a token-based mapping.

I was looking at creating states_to_token_maps only after the call is made (and the FSM is updated) but that would add too much overhead.

I was then thinking that a solution could be to create at initialization a mapping that contains both information about characters and about tokens (so we would have some states with no tokens leading to them that would be used for the token alignement)

@rlouf
Copy link
Member

rlouf commented Jan 27, 2024

How about looping over the entire vocabulary and store the tokens that accept mor as a prefix. Then, in the unconstrained case the first state of the FSM would have transitions to the overlapping tokens only?

Haven't taken the time to think about the constrained case yet.

@RobinPicard
Copy link
Contributor Author

I had not realized that I could walk the states_to_token_maps character by character for the postfix part of the crossing tokens in the constrained case. I think it works with almost no additional overhead like that. Let me know if you think it's fine and I'll update the tests afterward

@RobinPicard RobinPicard marked this pull request as ready for review January 30, 2024 09:08
outlines/fsm/fsm.py Outdated Show resolved Hide resolved
outlines/fsm/fsm.py Outdated Show resolved Hide resolved
outlines/fsm/fsm.py Outdated Show resolved Hide resolved
outlines/fsm/fsm.py Outdated Show resolved Hide resolved
outlines/fsm/fsm.py Outdated Show resolved Hide resolved
@rlouf
Copy link
Member

rlouf commented Feb 10, 2024

Yes I think that's the right approach. There's some stuff to figure out in terms of design, but otherwise looks good.

@rlouf rlouf changed the title Proposition implementation token alignment Implement prompt/generation alignment Feb 11, 2024
@RobinPicard
Copy link
Contributor Author

I'll write unit tests next if you think having those separate functions is the right design

outlines/fsm/fsm.py Outdated Show resolved Hide resolved
outlines/generate/api.py Outdated Show resolved Hide resolved
outlines/generate/api.py Outdated Show resolved Hide resolved
outlines/fsm/fsm.py Outdated Show resolved Hide resolved
Copy link
Member

@rlouf rlouf left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have made several comments on the overall design, but nothing that would dramatically affect your implementation. You can start implementing tests.

Copy link
Contributor

@shawnz shawnz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi there, I am just a user here who is looking forward to this change. However, I noticed that there is an error if the model is running on a GPU. I think it could be fixed by passing the device in these two statements here (at least, this fixes it for me)

outlines/generate/api.py Outdated Show resolved Hide resolved
outlines/generate/api.py Outdated Show resolved Hide resolved
@rlouf
Copy link
Member

rlouf commented Mar 1, 2024

We're getting really close. There are a few design changes remaining, and mostly we should have comprehensive tests before merging.

@rlouf
Copy link
Member

rlouf commented Mar 11, 2024

I rebased your branch on main after a big refactor of the FSM interface. I will take a closer look this week.

@lapp0
Copy link
Contributor

lapp0 commented Jun 20, 2024

@RobinPicard are you still interested in implementing this?

I can look at adapting it to the change made this week end

That's great news!

Please let me know if you run into any issues or have any questions about OutlinesLogitsProcessor.

You probably want to branch from #966 since it has fixes to the logits processors and a more detailed docstring.

@RobinPicard
Copy link
Contributor Author

RobinPicard commented Jun 25, 2024

To make sure I understand the wider context, the plan is to eventually remove SequenceGenerator and only use SequenceGeneratorAdapter @lapp0, right? If so, should we implement it for both of those or only the latter?

@rlouf
Copy link
Member

rlouf commented Jun 27, 2024

Indeed

@RobinPicard
Copy link
Contributor Author

I rebased on your branch and modified my initial commit @lapp0

@rlouf
Copy link
Member

rlouf commented Jul 16, 2024

Could you rebase on main now that #966 was merged?

@RobinPicard
Copy link
Contributor Author

This issue is causing problems for the PR. If we don't have an explanation/solution for it, we would have to modify the logic related to FSMLogitsProcessor._fsm_states. I don't know yet what we could replace it with if the order of sequences is not maintained + the values of previous tokens can change though.

@lapp0
Copy link
Contributor

lapp0 commented Jul 19, 2024

@RobinPicard per my comment in the linked issue, it appears that transformers beam search submits an unused sequence group to logits processors during the final generation step. Is this still an issue if it only occurs on the last step and it's not actually part of generation?

Please let me know how I can help.

@RobinPicard
Copy link
Contributor Author

It's fine if it only happens at the final generation step, I simply added a try:... except KeyError:... block

@lapp0
Copy link
Contributor

lapp0 commented Jul 19, 2024

Glad that it's not blocking, please let me know if you run into any other issues or have any questions!

@RobinPicard
Copy link
Contributor Author

I don't have more questions, I would be interested by a review of the PR though!

Copy link
Contributor

@lapp0 lapp0 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there might be a better way to implement this

Could you give me your thoughts on the following design:

If alignment is enabled, in

    1. SequenceGeneratorAdapter: Update the prompt token IDs in accordance with alignment, track removedchars
    1. guide.py: Create TokenHealingGuide which enforces generation starting with removedchars, and passes the "unhealed" tokens to the original Guide.
    1. SequenceGeneratorAdapter: left-truncate len(removedchars) once generation is complete

This would be forward-compatible with CFG, and minimize the the context of the token healing to just SequenceGeneratorAdapter and a TokenHealingGuide.

outlines/fsm/guide.py Outdated Show resolved Hide resolved
outlines/fsm/guide.py Outdated Show resolved Hide resolved
"""Update the states_to_token_maps and return the aligned prompt"""
token_ids, _ = tokenizer.encode(prompt)
# possible return types of tokenizers include list, 1d Tensor and 2d Tensor
if not isinstance(token_ids, list):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should ensure tokenizers have the same behavior rather than handling edge cases here.

Related: #936

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've made a commit for the encode/decode methods of the tokenizer

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great!

Could you please add a test which verifies correct functionality exists within tests/generate/test_generate.py against all outlines.models?

@pytest.mark.parametrize("pattern", REGEX_PATTERNS)
@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
def test_generate_token_healing_regex(request, model_fixture, pattern):
    model = request.getfixturevalue(model_fixture)
    # similar to test_generate_regex, but verifies a match of the pattern is generated with token healing as well

outlines/fsm/guide.py Outdated Show resolved Hide resolved
outlines/fsm/guide.py Outdated Show resolved Hide resolved
@@ -334,6 +418,10 @@ def __init__(self, cfg_string: str, tokenizer):
self.start_state = 0
self.final_state = -1

def align_prompt_tokens(self, prompt: str, tokenizer: "Tokenizer") -> str:
"""Not applicable to this type of Guide"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why doesn't alignment work for CFGGuide?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no states_to_token_map available before the start of the generation and there's the problem that the crossing tokens are not valid in the context free grammar. It's probably possible to implement it for CFGGuide, I just don't know how to do it at the moment.

outlines/fsm/guide.py Show resolved Hide resolved
def strip_completions(
self,
completions,
prompts: Union[str, List[str]],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SequenceGenerator is only used by exllamav2 at this point, and will no longer be used after #1010

Can we move this logic to SequenceGeneratorAdapter?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is already a method of SequenceGeneratorAdapter

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I mean is that none of the logic of prompt/generation alignment should exist in SequenceGenerator

outlines/integrations/llamacpp.py Outdated Show resolved Hide resolved
@RobinPicard
Copy link
Contributor Author

I don't get how step 2 would work. Are the unhealed tokens passed down to the original Guide just the user prompt?

Another design I'm considering for better separation of token healing and Guides is to create a dedicated class and have Guides that implement token healing inherit from it on top of GuideProtocol. Do you think that would be better?

@lapp0
Copy link
Contributor

lapp0 commented Jul 26, 2024

I don't get how step 2 would work. Are the unhealed tokens passed down to the original Guide just the user prompt?

Another design I'm considering for better separation of token healing and Guides is to create a dedicated class and have Guides that implement token healing inherit from it on top of GuideProtocol. Do you think that would be better?

Here's roughly what I'm thinking

In SequenceGeneratorAdapter

  • set the prompt as alignment_guide.prompt_prefix
  • then call model.generate,
  • then left-truncate generations

In AlignmentGuide

  • Constructor: AlignmentGuide(prompt, tokenizer, child_guide=None)
    • Let self.start_generation_tokens be the set of legal starting tokens for child_guide (or all tokens if no guide)
    • Tokenize (prompt, token) for token in start_generation_tokens and determine the "longest common self.prompt_prefix" and the set of legal self.prompt_suffix_ids
    • self.initial_state = AlignmentGuideState(legal_paths=legal_start_generation_tokens, child_guide_state=child_guide.initial_state)
  • For get_next_instruction(state)
    • If legal_paths None, alignment is complete. Defer to the child_guide.
    • If legal_paths is not None, tokens = [path[0] for path in state.legal_paths]
  • For get_next_state(state, token_id)
    • If legal paths: filter state.legal_paths such that it only includes those starting with token_id
      • If generating the final token in a path, pass the unaligned token to the child_guide, return `AlignmentGuideState(legal_paths=none, child_guide_state=new_child_guide_state)
  • If no legal_paths is None: Use the child guide to update child_guide_state

This design requires constructing a Guide for each prompt, however this is necessary because the behavior of the Guide varies prompt-to-prompt.

Details surrounding how to pre-process and manage state for AlignmentGuide may vary, this is just a rough outline.

Please let me know if this makes sense, or if I'm missing something.

@RobinPicard
Copy link
Contributor Author

I think I understand the general idea. The main issue I have concerns

  • If generating the final token in a path, pass the unaligned token to the child_guide

I don't get what the unaligned token means in this context.

@lapp0
Copy link
Contributor

lapp0 commented Aug 2, 2024

I don't get what the unaligned token means in this context.

Sorry I wasn't clear here.

For example, if the prompt is "hello wo" and we truncate it to "hello" for alignment, and the model generates the token " world" as the next output, "rld" is what needs to be passed to the child guide to continue processing.

This allows for alignment to be compatible with any guide with minimal changes.

@RobinPicard
Copy link
Contributor Author

Ah I see, but I thought the problem is that "rld" may not exist as a token in the vocabulary so it would not be found in the states_to_tokens_map. Or should we walk through the states_to_tokens_map of the child guide character by character in the get_next_state method of AlignmentGuide when the generation reached the crossing token?

@lapp0
Copy link
Contributor

lapp0 commented Aug 12, 2024

Sorry for the delayed response.

but I thought the problem is that "rld" may not exist as a token in the vocabulary so it would not be found in the states_to_tokens_map.

If I'm understanding the problem correctly, to mitigate this we need to determine the "longest common prompt prefix". This will allow any legal token to be generated as a "pseudo-token".

Or should we walk through the states_to_tokens_map of the child guide character by character in the get_next_state method of AlignmentGuide when the generation reached the crossing token?

Can we precompute this when the guide is constructed?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
correctness Everything related to the generation correctness enhancement transformers Linked to the `transformers` integration
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Implement prompt/generation alignment
4 participants