Implement prompt/generation alignment #531

Open · wants to merge 5 commits into base: main
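The diff adds a TokenHealerMixin, inherited by StopAtEOSGuide and RegexGuide, whose align_prompt_tokens method trims the trailing prompt tokens that could be merged with the first generated token and extends the guide's states_to_token_maps accordingly. As a minimal sketch of the core idea, here is how the new find_crossing_tokens helper (defined at the bottom of the diff) behaves on a toy vocabulary; the tokens and ids below are made up for illustration and assume this branch is installed:

from outlines.fsm.guide import find_crossing_tokens

# Toy vocabulary: token text -> token id (made up for illustration)
vocabulary = {"h": 0, "e": 1, "l": 2, "o": 3, "lo": 4, "llo": 5}
prompt_token_ids = [0, 1, 2, 2]  # the prompt "hell" tokenized as "h" "e" "l" "l"

print(find_crossing_tokens(prompt_token_ids, vocabulary))
# {3: [4, 5], 2: [5]}
# "lo" and "llo" start with the text of the trailing prompt token(s) and extend
# it by at least one character, so generation may reconsider those trailing
# tokens instead of being forced to continue strictly after "hell".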
255 changes: 230 additions & 25 deletions outlines/fsm/guide.py
@@ -1,3 +1,5 @@
from collections import defaultdict
from copy import copy, deepcopy
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
@@ -69,6 +71,9 @@ class Guide(Protocol):

"""

start_state: int = 0
final_state: int = -1

def get_next_instruction(self, state: int) -> Instruction:
...

@@ -82,11 +87,39 @@ def copy(self) -> "Guide":
...


class StopAtEOSGuide(Guide):
"""Guide to generate tokens until the EOS token has been generated."""
class TokenHealerMixin:
"""Class used to add the token align feature to a Guide"""

final_state = 1
start_state = 0
states_to_token_maps: Dict[int, Dict[int, int]]
tokenizer: "Tokenizer"

def align_prompt_tokens(self, prompt: str) -> str:
"""Update the states_to_token_maps and return the aligned prompt"""
token_ids, _ = self.tokenizer.encode(prompt)
(
aligned_token_ids,
aligned_states_to_token_maps,
) = align_tokens_states_to_token_maps(
token_ids.tolist()[0],
self.tokenizer.vocabulary,
deepcopy(self.states_to_token_maps),
)
aligned_prompt = self.tokenizer.decode([aligned_token_ids])[0]
# some models do not accept an empty string as a prompt
# if token alignment would remove all tokens, do not apply it
if not aligned_prompt:
return prompt
self.states_to_token_maps = aligned_states_to_token_maps
if hasattr(self, "_cache_state_to_token_tensor"):
self._cache_state_to_token_tensor()
# remove leading whitespace if added by the tokenizer
if aligned_prompt[0] == " " and prompt[0] != " ":
return aligned_prompt[1:]
return aligned_prompt


class StopAtEOSGuide(Guide, TokenHealerMixin):
"""Guide to generate tokens until the EOS token has been generated."""

def __init__(self, tokenizer: "Tokenizer"):
"""Initialize the generation guide.
@@ -95,25 +128,37 @@ def __init__(self, tokenizer: "Tokenizer"):
The logit generator used to generate the next token.

"""
self.eos_token_id = tokenizer.eos_token_id
self.vocabulary = tokenizer.vocabulary.values()
self.tokenizer = tokenizer
self.states_to_token_maps = self.create_states_to_tokens_map()

def create_states_to_tokens_map(self) -> Dict[int, Dict[int, int]]:
"""Create the states_to_tokens_map. All tokens lead to the starting
state, except for the eos_token, which leads to the final state."""
return {
self.start_state: {
token_id: self.start_state
if token_id != self.tokenizer.eos_token_id
else self.final_state
for token_id in self.tokenizer.vocabulary.values()
}
}

def get_next_instruction(self, state: int) -> Instruction:
if self.is_final_state(state):
return Write([self.eos_token_id])
return Generate(None)
return Write([self.tokenizer.eos_token_id])
return Generate(list(self.states_to_token_maps[state].keys()))

def get_next_state(self, state: int, token_id: int) -> int:
if token_id == self.eos_token_id or state == self.final_state:
if self.is_final_state(state):
return self.final_state

return self.start_state
return self.states_to_token_maps[state][token_id]

def is_final_state(self, state: int):
return state == self.final_state

def copy(self):
return self
return copy(self)


@cache()
@@ -171,20 +216,20 @@ def create_states_mapping(
return states_to_token_maps, empty_token_ids, regex_fsm.finals


class RegexGuide(Guide):
class RegexGuide(Guide, TokenHealerMixin):
"""Guide to generate text in the language of a regular expression."""

initial_state = 0
states_to_token_mask: Dict[int, torch.Tensor]

def __init__(self, regex_string: str, tokenizer: "Tokenizer"):
self.tokenizer = tokenizer
(
self.states_to_token_maps,
self.empty_token_ids,
fsm_finals,
) = create_states_mapping(regex_string, tokenizer)
self.eos_token_id = tokenizer.eos_token_id
self.final_states = fsm_finals | {-1}
self._cache_state_to_token_tensor()
self.final_states = fsm_finals | {self.final_state}

def get_next_instruction(self, state: int) -> Instruction:
"""Return the next instruction for guided generation.
@@ -211,7 +256,7 @@ def get_next_instruction(self, state: int) -> Instruction:
"""
next_tokens_mask = self.states_to_token_mask.get(state)
if next_tokens_mask is None:
return Write(torch.tensor([self.eos_token_id]))
return Write(torch.tensor([self.tokenizer.eos_token_id]))

return Generate(next_tokens_mask)

@@ -233,13 +278,16 @@ def get_next_state(self, state: int, token_id: int) -> int:
The new state of the guide.

"""
if token_id == self.eos_token_id or state not in self.states_to_token_maps:
return -1
if (
token_id == self.tokenizer.eos_token_id
or state not in self.states_to_token_maps
):
return self.final_state

last_token_to_end_state = self.states_to_token_maps[state]
next_state = last_token_to_end_state.get(token_id)
if next_state is None:
next_state = -1
next_state = self.final_state

return next_state

@@ -278,11 +326,11 @@ def create_states_mapping_from_interegular_fsm(
from_interegular_instance.states_to_token_maps,
from_interegular_instance.empty_token_ids,
) = create_states_mapping_from_interegular_fsm(interegular_fsm)
from_interegular_instance.eos_token_id = tokenizer.eos_token_id
from_interegular_instance.tokenizer = tokenizer
from_interegular_instance._cache_state_to_token_tensor()
return from_interegular_instance

def _cache_state_to_token_tensor(self):
def _cache_state_to_token_tensor(self) -> None:
"""
cache state -> token int tensor
this increases performance of mask construction substantially
@@ -297,7 +345,7 @@ def is_final_state(self, state: int) -> bool:
return state in self.final_states

def copy(self):
return self
return copy(self)


class CFGGuide(Guide):
@@ -331,9 +379,6 @@ def __init__(self, cfg_string: str, tokenizer):
self.proposal_last: List[int] = []
self.regex_fsm_last: RegexGuide

self.start_state = 0
self.final_state = -1

def get_next_instruction(self, state: int) -> Instruction:
"""Generate an instruction for the next step.

@@ -475,3 +520,163 @@ def is_final_state(self, state: int) -> bool:
def copy(self) -> "CFGGuide":
"""Create a copy of the FSM."""
return CFGGuide(self.cfg_string, self.tokenizer)


def align_tokens_states_to_token_maps(
token_ids: List[int],
vocabulary: Dict[str, int],
states_to_token_maps: Dict[int, Dict[int, int]],
) -> Tuple[List[int], Dict[int, Dict[int, int]]]:
"""Apply token alignment to the provided prompt tokens and attention masks given the
states_to_token_maps of a FSM. Return the updated tokens/maps as well as the updated
states_to_token_maps. You can find an explanation from Guidance on why token healing
is necessary here:
https://github.com/guidance-ai/guidance/blob/main/notebooks/tutorials/token_healing.ipynb
"""
crossing_tokens = find_crossing_tokens(token_ids, vocabulary)
valid_crossing_tokens = get_crossing_tokens_target_states(
states_to_token_maps, crossing_tokens, token_ids, vocabulary
)
if not valid_crossing_tokens:
return token_ids, states_to_token_maps
(
states_to_token_maps,
number_cropped_tokens,
) = add_crossing_tokens_states_to_tokens_map(
states_to_token_maps, token_ids, valid_crossing_tokens
)
return (
token_ids[:-number_cropped_tokens],
states_to_token_maps,
)


def find_crossing_tokens(
token_ids: List[int], vocabulary: Dict[str, int]
) -> Dict[int, List[int]]:
"""Find the tokens that could replace one or more tokens at the end of token_ids
while preserving the same initial text (and extending it by at least one character).
Return a dictionary mapping each matching index of token_ids to its associated crossing tokens.
"""
reversed_vocabulary = {value: key for key, value in vocabulary.items()}
len_token_ids = len(token_ids)
max_length_token_text = max(len(item) for item in vocabulary.keys())
characters_considered = ""
crossing_tokens_map = {}

for index, token_id in enumerate(reversed(token_ids)):
characters_considered = reversed_vocabulary[token_id] + characters_considered
if len(characters_considered) >= max_length_token_text:
break
crossing_token_ids = [
token_id
for text, token_id in vocabulary.items()
if text.startswith(characters_considered)
and len(text) > len(characters_considered)
]
if crossing_token_ids:
crossing_tokens_map[len_token_ids - index - 1] = crossing_token_ids

return crossing_tokens_map


def get_crossing_tokens_target_states(
states_to_tokens_map: Dict[int, Dict[int, int]],
crossing_tokens: Dict[int, List[int]],
prompt_token_ids: List[int],
vocabulary: Dict[str, int],
) -> Dict[int, Dict[int, int]]:
"""For each crossing token associated to an index, check that the characters after the boundary
match the states_to_tokens_map and find the state it would lead to. Return a dict with, for each
provided indexes, the associated valid tokens with the state they would lead to.
"""
reversed_vocabulary = {value: key for key, value in vocabulary.items()}
prompt_token_texts = [
reversed_vocabulary[token_id] for token_id in prompt_token_ids
]

valid_crossing_tokens: Dict[int, Dict[int, int]] = defaultdict(dict)
for pos, tokens in crossing_tokens.items():
for token in tokens:
is_valid = True
characters = reversed_vocabulary[token]
characters_before_border = "".join(prompt_token_texts[pos:])
characters_after_border = characters[len(characters_before_border) :]
state = 0
for char in characters_after_border:
char_token = vocabulary.get(char)
try:
state = states_to_tokens_map[state][char_token] # type: ignore
except KeyError:
is_valid = False
break
if is_valid:
valid_crossing_tokens[pos][token] = state

return valid_crossing_tokens


def add_crossing_tokens_states_to_tokens_map(
states_to_tokens_map: Dict[int, Dict[int, int]],
prompt_token_ids: List[int],
crossing_tokens_map: Dict[int, Dict[int, int]],
) -> Tuple[Dict[int, Dict[int, int]], int]:
"""Modify the states_to_tokens_map to account for the crossing tokens. This operation modifies
the starting state of the fsm as we would include some characters at the end of the prompt in
the states_to_tokens_map.
Attention! the starting state of the states_to_tokens_map provided must be 0.
Return the updated states_to_tokens_map and the number of cropped tokens/additional states
"""
if not crossing_tokens_map:
return states_to_tokens_map, 0
first_crossing_token_pos = min(
[key for key, value in crossing_tokens_map.items() if value]
)
number_additional_states = len(prompt_token_ids) - first_crossing_token_pos
highest_state = max(
max(states_to_tokens_map.keys()),
max(max(items.values()) for items in states_to_tokens_map.values()),
)

for i in range(number_additional_states):
# add the token that was originally part of the prompt
if i == number_additional_states - 1:
states_to_tokens_map[highest_state + 1 + i] = {
prompt_token_ids[first_crossing_token_pos + i]: 0
}
else:
states_to_tokens_map[highest_state + 1 + i] = {
prompt_token_ids[first_crossing_token_pos + i]: highest_state + 2 + i
}
# add the crossing tokens
crossing_tokens = crossing_tokens_map.get(first_crossing_token_pos + i)
if crossing_tokens:
for token, target_state in crossing_tokens.items():
states_to_tokens_map[highest_state + 1 + i][token] = target_state

# set the id of our new initial state to 0
states_to_tokens_map = swap_state_ids_states_to_tokens_map(
states_to_tokens_map, highest_state + 1, 0
)
return states_to_tokens_map, number_additional_states


def swap_state_ids_states_to_tokens_map(
states_to_tokens_map: Dict[int, Dict[int, int]],
first_state_id: int,
second_state_id: int,
) -> Dict[int, Dict[int, int]]:
"""Swap the id of two states of the states_to_tokens_map while conserving all transitions"""
first_state_transitions = states_to_tokens_map.pop(first_state_id)
second_state_transitions = states_to_tokens_map.pop(second_state_id)
states_to_tokens_map[first_state_id] = second_state_transitions
states_to_tokens_map[second_state_id] = first_state_transitions

for transitions in states_to_tokens_map.values():
for token, target_state_id in list(transitions.items()):
if target_state_id == first_state_id:
transitions[token] = second_state_id
elif target_state_id == second_state_id:
transitions[token] = first_state_id

return states_to_tokens_map
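
A minimal sketch of how the new align_tokens_states_to_token_maps helper behaves end to end, assuming this branch is installed; the vocabulary, token ids, and FSM below are made up for illustration (the FSM only accepts the text "bc", and the prompt "xa" is tokenized as "x", "a"):

from outlines.fsm.guide import align_tokens_states_to_token_maps

vocabulary = {"x": 6, "a": 1, "b": 2, "ab": 3, "c": 4}
states_to_token_maps = {0: {2: 1}, 1: {4: 2}}  # 0 -"b"-> 1 -"c"-> 2
prompt_token_ids = [6, 1]  # "x", "a"

aligned_ids, aligned_maps = align_tokens_states_to_token_maps(
    prompt_token_ids, vocabulary, states_to_token_maps
)
print(aligned_ids)   # [6]: the trailing "a" is cropped from the prompt
print(aligned_maps)  # {1: {4: 2}, 3: {2: 1}, 0: {1: 3, 3: 1}}
# From the new start state 0 the model may either regenerate "a" (token 1) and
# continue as before (state 3 now plays the role of the old start state), or
# emit the crossing token "ab" (token 3) and land directly in state 1, as if
# the FSM had already consumed the "b".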