diff --git a/outlines/fsm/guide.py b/outlines/fsm/guide.py index b7b121fe6..ec3f9b53c 100644 --- a/outlines/fsm/guide.py +++ b/outlines/fsm/guide.py @@ -90,6 +90,43 @@ def is_final_state(self, state: Any) -> bool: def copy(self) -> "Guide": ... + def accepts(self, token_ids: List[int], state=None) -> bool: + """ + Determine whether the sequence, `token_ids`, is accepted by the Guide. + `token_ids` doesn't need to complete the guide to be accepted. + """ + try: + self.derive(token_ids, state) + return True + except ValueError: + return False + + def derive(self, token_ids: List[int], state=None) -> Union["Guide", None]: + """ + TODO: Docstring + """ + if state is None: + state = self.initial_state + for token_id in token_ids: + instruction = self.get_next_instruction(state) + + # determine if token_id allowed by instruction + if isinstance(instruction, Write): + raise NotImplementedError("TODO") + elif isinstance(instruction, Generate): + if ( + instruction.tokens is not None + and token_id not in instruction.tokens + ): + raise ValueError("Cannot advance state with provided token_ids") + else: + raise TypeError(f"Expected instruction, got {instruction}") + + # advance state + state = self.get_next_state(state, token_id) + + return state + class StopAtEOSGuide(Guide): """Guide to generate tokens until the EOS token has been generated.""" @@ -487,3 +524,214 @@ def must_terminate_state(self, state: CFGState) -> bool: def copy(self) -> "CFGGuide": """Create a copy of the Guide.""" return CFGGuide(self.cfg_string, self.tokenizer) + + +@cache() +def build_vocab_prefix_map(tokenizer: "Tokenizer") -> Dict[str, Set[Tuple[str, Tuple]]]: + """Build a map from token prefix to Set[Tuple[suffix, aligment_token_id, suffix_token_ids]]""" + + # precompute the token ids of all vocab suffixes + suffixes = list( + {tok[i:] for tok in tokenizer.vocabulary for i in range(1, len(tok))} + ) + encoded_suffixes, _ = tokenizer.encode(suffixes) + encoded_suffixes = [ + [tok for tok in seq_ids if tok != tokenizer.pad_token_id] + for seq_ids in encoded_suffixes.tolist() + ] + suffix_map = dict(zip(suffixes, map(tuple, encoded_suffixes))) + suffix_map[""] = tuple() + + # compute prefix-suffix map for all tokens, s.t. prefix + suffix = token + prefix_map = collections.defaultdict(set) + for token, token_id in tokenizer.vocabulary.items(): + for i in range(1, len(token) + 1): + prefix_map[token[:i]].add((token[i:], suffix_map[token[i:]])) + return prefix_map + + +AlignmentGuideState = collections.namedtuple( + "AlignmentGuideState", ["legal_path_map", "child_guide_state"] +) + + +class AlignmentGuide(Guide): + def __init__( + self, prompt: str, tokenizer: "Tokenizer", child_guide: Optional[Guide] = None + ): + """ + Initialize the AlignmentGuide with a prompt, tokenizer, and an optional child guide. + + Parameters + ---------- + prompt : str + The prompt text to be aligned with the generated tokens. + tokenizer : Tokenizer + Tokenizer used to align the prompt. + child_guide : Guide, optional + A guide to take control after alignment is complete. None -> Unconstrained after alignment + """ + self.prompt = prompt + self.tokenizer = tokenizer + self.child_guide = child_guide + + alignment_seqs, child_guide_ids = self._get_alignment_sequences( + prompt, tokenizer, child_guide + ) + alignment_prompt_ids, common_prompt_len = self._get_longest_common_prompt_ids( + alignment_seqs + ) + + self.alignment_prompt = self.tokenizer.decode( + [alignment_seqs[0, :common_prompt_len]] + )[0] + + # calculate map of alignment_prompt continuation tokens -> child_guide advancement tokens + legal_paths = [ + tuple([t for t in seq if t != tokenizer.pad_token_id]) + for seq in alignment_seqs[:, common_prompt_len:].tolist() + ] + legal_path_map = dict(zip(legal_paths, child_guide_ids)) + + self.initial_state = AlignmentGuideState( + legal_path_map=legal_path_map, child_guide_state=None + ) + + @staticmethod + def _get_alignment_sequences( + prompt: str, tokenizer: "Tokenizer", child_guide: Optional[Guide] = None + ): + """ + Calculate all possible sequences which are valid with a prompt + child_guide + E.g. prompt="hello wo", child guide accepts "rld" -> tokenization ["hello", "world"] is valid + + Returns tuple of (alignment_seqs, child_guide_ids) of same length + - alignment_seqs: + All token sequences which can represent `prompt` + start of generation. The last token + must represent the end of the prompt can extend beyond the prompt to start generation. + Sequences are only included if the start of generation portion is legal with child guide. + - child_guide_ids: + Token to send to the child guide to simulate the start of generation. In the example above + "world" is the last alignment seq token, therefore we must advance the state of the child + guide with the tokenization of "rld" in order to continue generation with the child guide. + """ + guide_accepts: Dict[ + Tuple[int], bool + ] = {} # cache of suffix acceptance for child_guide.accepts() + + # prompts with alignment tokens at end + aligned_prompt_completions: List[str] = [] + # tokens to feed child guide once alignment completes + child_guide_ids: List[Tuple] = [] + + # compute alignment seqs which are valid with prompt and child guide + for prefix, alignment_details in build_vocab_prefix_map(tokenizer).items(): + if prompt.endswith(prefix): + for suffix, suffix_ids in alignment_details: + if child_guide is None: + aligned_prompt_completions.append(prompt + suffix) + child_guide_ids.append(tuple()) + elif guide_accepts.setdefault( + suffix_ids, child_guide.accepts(suffix_ids) + ): + aligned_prompt_completions.append(prompt + suffix) + child_guide_ids.append(suffix_ids) + + alignment_seqs, _ = tokenizer.encode(aligned_prompt_completions) + return alignment_seqs, child_guide_ids + + @staticmethod + def _get_longest_common_prompt_ids(alignment_seqs): + """ + Among all candidate prompt alignment seqs, get the longest shared prefix and their length + """ + # get longest common prefix among alignment sequences, which will form our alignment prompt + common = ( + (alignment_seqs.unsqueeze(1) == alignment_seqs.unsqueeze(0)) + .all(0) + .cumprod(1) + ) + common_len = common.sum(1).max().item() + return alignment_seqs[0, :common_len], common_len + + def get_next_instruction(self, state: AlignmentGuideState) -> Instruction: + """ + Return the next set of valid tokens for generation based on the current state. + + If alignment hasn't completed: + tokens which continue one of the candidate alignment paths are legal + If alignment has completed: + get instruction from the child guide + """ + if state.legal_path_map is not None: + return Generate( + sorted({token_ids[0] for token_ids in state.legal_path_map.keys()}) + ) + elif self.child_guide is None: + return Generate(None) + else: + return self.child_guide.get_next_instruction(state.child_guide_state) + + def get_next_state( + self, state: AlignmentGuideState, token_id: int + ) -> AlignmentGuideState: + """ + Get AlignmentGuideState advanced by token ID. + + If alignment has completed: + get instruction from the child guide + If alignment hasn't completed: + Filter out alignment paths which don't start with token_id + Remove First token from remaining paths + If advancing the state completes alignment: + Advance the child_guide state + """ + if state.legal_path_map is None: + if self.child_guide is not None: + return AlignmentGuideState( + legal_path_map=None, + child_guide_state=self.child_guide.get_next_state( + state.child_guide_state, token_id + ), + ) + else: + return AlignmentGuideState(None, None) + else: + next_state_legal_path_map = { + key[1:]: value + for key, value in state.legal_path_map.items() + if key[0] == token_id + } + # if none remaining, advance the child guide + if not any(next_state_legal_path_map): + if self.child_guide is not None: + child_guide_advancement_ids = next( + iter(next_state_legal_path_map.values()) + ) + return AlignmentGuideState( + legal_path_map=None, + child_guide_state=self.child_guide.derive( + child_guide_advancement_ids, state.child_guide_state + ), + ) + else: + return AlignmentGuideState(None, None) + + # if paths remaining, return advanced legal_path_map + else: + return AlignmentGuideState( + legal_path_map=next_state_legal_path_map, + child_guide_state=state.child_guide_state, + ) + + def is_final_state(self, state: AlignmentGuideState) -> bool: + if state.legal_path_map is not None: + return False + elif self.child_guide is None: + return True + else: + return self.child_guide.is_final_state(state.child_guide_state) + + def copy(self): + """AlignmentGuide isn't mutated""" + return self diff --git a/tests/fsm/test_alignment_guide.py b/tests/fsm/test_alignment_guide.py new file mode 100644 index 000000000..3d823b7a7 --- /dev/null +++ b/tests/fsm/test_alignment_guide.py @@ -0,0 +1,87 @@ +import torch +from transformers import AutoTokenizer + +from outlines.fsm.guide import AlignmentGuide, RegexGuide +from outlines.models.transformers import TransformerTokenizer + + +class MockTokenizer: + def __init__(self, vocabulary): + self.vocabulary = {tok: i for i, tok in enumerate(vocabulary)} + self.vocabulary[""] = len(self.vocabulary) + self.special_tokens = {""} + self.eos_token_id = self.vocabulary[""] + self.pad_token_id = -1 + + self.inverse_vocabulary = {i: tok for tok, i in self.vocabulary.items()} + + def convert_token_to_string(self, token): + return token + + def decode(self, token_ids): + if token_ids == []: + return "" + if isinstance(list(token_ids)[0], list): + return [ + "".join(map(self.inverse_vocabulary.get, token_ids_sublist)) + for token_ids_sublist in token_ids + ] + return [self.inverse_vocabulary[int(token_id)] for token_id in token_ids] + + def encode(self, texts): + """ + Encodes the input texts by finding the longest matching tokens in the vocabulary. + """ + seqs = [] + for text in texts: + tokens = [] + while text: + token = next( + ( + tok + for tok in sorted(self.vocabulary, key=len, reverse=True) + if text.startswith(tok) + ), + None, + ) + if token is None: + tokens = [self.pad_token_id] + break + tokens.append(self.vocabulary[token]) + text = text[len(token) :] + seqs.append(tokens) + + max_len = max(len(seq) for seq in seqs) + padded_seqs = torch.tensor( + [seq + [self.pad_token_id] * (max_len - len(seq)) for seq in seqs] + ) + return padded_seqs, None + + +def test_alignment_with_pseudo_token_and_regex_guide(): + # Mock tokenizer with the vocabulary for "hello", "world", "wo", "rld", and "!" + tokenizer = MockTokenizer(["hello", " world", " wo", "rld", "!"]) + prompt = "hello wo" + + # Create a RegexGuide that expects the sequence "rld!" + child_guide = RegexGuide(regex_string="rld!", tokenizer=tokenizer) + + # Create the AlignmentGuide with the child guide + guide = AlignmentGuide(prompt, tokenizer, child_guide=child_guide) + + assert guide.alignment_prompt == "hello" + + # assert " world!" is legal and final + seq = [tokenizer.vocabulary[" world"], tokenizer.vocabulary["!"]] + assert guide.accepts(seq) + assert guide.is_final_state(guide.derive(seq, guide.initial_state)) is True + + +def test_alignment_guide_gpt2_url(): + # Based on notebook + # https://github.com/guidance-ai/guidance/blob/af63e6/notebooks/tutorials/token_healing.ipynb#L4 + tokenizer = TransformerTokenizer(AutoTokenizer.from_pretrained("gpt2")) + prompt = "The url of Google is http:" + guide = AlignmentGuide(prompt, tokenizer) + assert guide.alignment_prompt == "The url of Google is http" + assert guide.accepts(list(tokenizer.encode("://google.com")[0][0]))