Skip to content

Commit

Permalink
Implement AlignmentGuide
Browse files Browse the repository at this point in the history
  • Loading branch information
lapp0 committed Sep 21, 2024
1 parent 289ef5d commit 54aaf87
Show file tree
Hide file tree
Showing 2 changed files with 335 additions and 0 deletions.
248 changes: 248 additions & 0 deletions outlines/fsm/guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,43 @@ def is_final_state(self, state: Any) -> bool:
def copy(self) -> "Guide":
...

def accepts(self, token_ids: List[int], state=None) -> bool:
"""
Determine whether the sequence, `token_ids`, is accepted by the Guide.
`token_ids` doesn't need to complete the guide to be accepted.
"""
try:
self.derive(token_ids, state)
return True
except ValueError:
return False

def derive(self, token_ids: List[int], state=None) -> Union["Guide", None]:
"""
TODO: Docstring
"""
if state is None:
state = self.initial_state
for token_id in token_ids:
instruction = self.get_next_instruction(state)

# determine if token_id allowed by instruction
if isinstance(instruction, Write):
raise NotImplementedError("TODO")
elif isinstance(instruction, Generate):
if (
instruction.tokens is not None
and token_id not in instruction.tokens
):
raise ValueError("Cannot advance state with provided token_ids")
else:
raise TypeError(f"Expected instruction, got {instruction}")

# advance state
state = self.get_next_state(state, token_id)

return state


class StopAtEOSGuide(Guide):
"""Guide to generate tokens until the EOS token has been generated."""
Expand Down Expand Up @@ -487,3 +524,214 @@ def must_terminate_state(self, state: CFGState) -> bool:
def copy(self) -> "CFGGuide":
"""Create a copy of the Guide."""
return CFGGuide(self.cfg_string, self.tokenizer)


@cache()
def build_vocab_prefix_map(tokenizer: "Tokenizer") -> Dict[str, Set[Tuple[str, Tuple]]]:
"""Build a map from token prefix to Set[Tuple[suffix, aligment_token_id, suffix_token_ids]]"""

# precompute the token ids of all vocab suffixes
suffixes = list(
{tok[i:] for tok in tokenizer.vocabulary for i in range(1, len(tok))}
)
encoded_suffixes, _ = tokenizer.encode(suffixes)
encoded_suffixes = [
[tok for tok in seq_ids if tok != tokenizer.pad_token_id]
for seq_ids in encoded_suffixes.tolist()
]
suffix_map = dict(zip(suffixes, map(tuple, encoded_suffixes)))
suffix_map[""] = tuple()

# compute prefix-suffix map for all tokens, s.t. prefix + suffix = token
prefix_map = collections.defaultdict(set)
for token, token_id in tokenizer.vocabulary.items():
for i in range(1, len(token) + 1):
prefix_map[token[:i]].add((token[i:], suffix_map[token[i:]]))
return prefix_map


AlignmentGuideState = collections.namedtuple(
"AlignmentGuideState", ["legal_path_map", "child_guide_state"]
)


class AlignmentGuide(Guide):
def __init__(
self, prompt: str, tokenizer: "Tokenizer", child_guide: Optional[Guide] = None
):
"""
Initialize the AlignmentGuide with a prompt, tokenizer, and an optional child guide.
Parameters
----------
prompt : str
The prompt text to be aligned with the generated tokens.
tokenizer : Tokenizer
Tokenizer used to align the prompt.
child_guide : Guide, optional
A guide to take control after alignment is complete. None -> Unconstrained after alignment
"""
self.prompt = prompt
self.tokenizer = tokenizer
self.child_guide = child_guide

alignment_seqs, child_guide_ids = self._get_alignment_sequences(
prompt, tokenizer, child_guide
)
alignment_prompt_ids, common_prompt_len = self._get_longest_common_prompt_ids(
alignment_seqs
)

self.alignment_prompt = self.tokenizer.decode(
[alignment_seqs[0, :common_prompt_len]]
)[0]

# calculate map of alignment_prompt continuation tokens -> child_guide advancement tokens
legal_paths = [
tuple([t for t in seq if t != tokenizer.pad_token_id])
for seq in alignment_seqs[:, common_prompt_len:].tolist()
]
legal_path_map = dict(zip(legal_paths, child_guide_ids))

self.initial_state = AlignmentGuideState(
legal_path_map=legal_path_map, child_guide_state=None
)

@staticmethod
def _get_alignment_sequences(
prompt: str, tokenizer: "Tokenizer", child_guide: Optional[Guide] = None
):
"""
Calculate all possible sequences which are valid with a prompt + child_guide
E.g. prompt="hello wo", child guide accepts "rld" -> tokenization ["hello", "world"] is valid
Returns tuple of (alignment_seqs, child_guide_ids) of same length
- alignment_seqs:
All token sequences which can represent `prompt` + start of generation. The last token
must represent the end of the prompt can extend beyond the prompt to start generation.
Sequences are only included if the start of generation portion is legal with child guide.
- child_guide_ids:
Token to send to the child guide to simulate the start of generation. In the example above
"world" is the last alignment seq token, therefore we must advance the state of the child
guide with the tokenization of "rld" in order to continue generation with the child guide.
"""
guide_accepts: Dict[
Tuple[int], bool
] = {} # cache of suffix acceptance for child_guide.accepts()

# prompts with alignment tokens at end
aligned_prompt_completions: List[str] = []
# tokens to feed child guide once alignment completes
child_guide_ids: List[Tuple] = []

# compute alignment seqs which are valid with prompt and child guide
for prefix, alignment_details in build_vocab_prefix_map(tokenizer).items():
if prompt.endswith(prefix):
for suffix, suffix_ids in alignment_details:
if child_guide is None:
aligned_prompt_completions.append(prompt + suffix)
child_guide_ids.append(tuple())
elif guide_accepts.setdefault(
suffix_ids, child_guide.accepts(suffix_ids)
):
aligned_prompt_completions.append(prompt + suffix)
child_guide_ids.append(suffix_ids)

alignment_seqs, _ = tokenizer.encode(aligned_prompt_completions)
return alignment_seqs, child_guide_ids

@staticmethod
def _get_longest_common_prompt_ids(alignment_seqs):
"""
Among all candidate prompt alignment seqs, get the longest shared prefix and their length
"""
# get longest common prefix among alignment sequences, which will form our alignment prompt
common = (
(alignment_seqs.unsqueeze(1) == alignment_seqs.unsqueeze(0))
.all(0)
.cumprod(1)
)
common_len = common.sum(1).max().item()
return alignment_seqs[0, :common_len], common_len

def get_next_instruction(self, state: AlignmentGuideState) -> Instruction:
"""
Return the next set of valid tokens for generation based on the current state.
If alignment hasn't completed:
tokens which continue one of the candidate alignment paths are legal
If alignment has completed:
get instruction from the child guide
"""
if state.legal_path_map is not None:
return Generate(
sorted({token_ids[0] for token_ids in state.legal_path_map.keys()})
)
elif self.child_guide is None:
return Generate(None)
else:
return self.child_guide.get_next_instruction(state.child_guide_state)

def get_next_state(
self, state: AlignmentGuideState, token_id: int
) -> AlignmentGuideState:
"""
Get AlignmentGuideState advanced by token ID.
If alignment has completed:
get instruction from the child guide
If alignment hasn't completed:
Filter out alignment paths which don't start with token_id
Remove First token from remaining paths
If advancing the state completes alignment:
Advance the child_guide state
"""
if state.legal_path_map is None:
if self.child_guide is not None:
return AlignmentGuideState(
legal_path_map=None,
child_guide_state=self.child_guide.get_next_state(
state.child_guide_state, token_id
),
)
else:
return AlignmentGuideState(None, None)
else:
next_state_legal_path_map = {
key[1:]: value
for key, value in state.legal_path_map.items()
if key[0] == token_id
}
# if none remaining, advance the child guide
if not any(next_state_legal_path_map):
if self.child_guide is not None:
child_guide_advancement_ids = next(
iter(next_state_legal_path_map.values())
)
return AlignmentGuideState(
legal_path_map=None,
child_guide_state=self.child_guide.derive(
child_guide_advancement_ids, state.child_guide_state
),
)
else:
return AlignmentGuideState(None, None)

# if paths remaining, return advanced legal_path_map
else:
return AlignmentGuideState(
legal_path_map=next_state_legal_path_map,
child_guide_state=state.child_guide_state,
)

def is_final_state(self, state: AlignmentGuideState) -> bool:
if state.legal_path_map is not None:
return False
elif self.child_guide is None:
return True
else:
return self.child_guide.is_final_state(state.child_guide_state)

def copy(self):
"""AlignmentGuide isn't mutated"""
return self
87 changes: 87 additions & 0 deletions tests/fsm/test_alignment_guide.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import torch
from transformers import AutoTokenizer

from outlines.fsm.guide import AlignmentGuide, RegexGuide
from outlines.models.transformers import TransformerTokenizer


class MockTokenizer:
def __init__(self, vocabulary):
self.vocabulary = {tok: i for i, tok in enumerate(vocabulary)}
self.vocabulary["<eos>"] = len(self.vocabulary)
self.special_tokens = {"<eos>"}
self.eos_token_id = self.vocabulary["<eos>"]
self.pad_token_id = -1

self.inverse_vocabulary = {i: tok for tok, i in self.vocabulary.items()}

def convert_token_to_string(self, token):
return token

def decode(self, token_ids):
if token_ids == []:
return ""
if isinstance(list(token_ids)[0], list):
return [
"".join(map(self.inverse_vocabulary.get, token_ids_sublist))
for token_ids_sublist in token_ids
]
return [self.inverse_vocabulary[int(token_id)] for token_id in token_ids]

def encode(self, texts):
"""
Encodes the input texts by finding the longest matching tokens in the vocabulary.
"""
seqs = []
for text in texts:
tokens = []
while text:
token = next(
(
tok
for tok in sorted(self.vocabulary, key=len, reverse=True)
if text.startswith(tok)
),
None,
)
if token is None:
tokens = [self.pad_token_id]
break
tokens.append(self.vocabulary[token])
text = text[len(token) :]
seqs.append(tokens)

max_len = max(len(seq) for seq in seqs)
padded_seqs = torch.tensor(
[seq + [self.pad_token_id] * (max_len - len(seq)) for seq in seqs]
)
return padded_seqs, None


def test_alignment_with_pseudo_token_and_regex_guide():
# Mock tokenizer with the vocabulary for "hello", "world", "wo", "rld", and "!"
tokenizer = MockTokenizer(["hello", " world", " wo", "rld", "!"])
prompt = "hello wo"

# Create a RegexGuide that expects the sequence "rld!"
child_guide = RegexGuide(regex_string="rld!", tokenizer=tokenizer)

# Create the AlignmentGuide with the child guide
guide = AlignmentGuide(prompt, tokenizer, child_guide=child_guide)

assert guide.alignment_prompt == "hello"

# assert " world!" is legal and final
seq = [tokenizer.vocabulary[" world"], tokenizer.vocabulary["!"]]
assert guide.accepts(seq)
assert guide.is_final_state(guide.derive(seq, guide.initial_state)) is True


def test_alignment_guide_gpt2_url():
# Based on notebook
# https://github.com/guidance-ai/guidance/blob/af63e6/notebooks/tutorials/token_healing.ipynb#L4
tokenizer = TransformerTokenizer(AutoTokenizer.from_pretrained("gpt2"))
prompt = "The url of Google is http:"
guide = AlignmentGuide(prompt, tokenizer)
assert guide.alignment_prompt == "The url of Google is http"
assert guide.accepts(list(tokenizer.encode("://google.com")[0][0]))

0 comments on commit 54aaf87

Please sign in to comment.