Implement prompt token alignment in FSMLogitsProcessor and in SequenceGeneratorAdapter
RobinPicard committed Jul 23, 2024
1 parent 017597a commit 4f640c4
Showing 3 changed files with 215 additions and 24 deletions.
132 changes: 124 additions & 8 deletions outlines/generate/api.py
@@ -1,16 +1,16 @@
import datetime
from dataclasses import dataclass
from typing import TYPE_CHECKING, Iterator, List, Optional, Union
from typing import Iterator, List, Optional, Sequence, Union

import torch

from outlines.generate.generator import sequence_generator
from outlines.samplers import BeamSearchSampler, GreedySampler, MultinomialSampler

if TYPE_CHECKING:
import torch

FormattedOutput = Union[
str, int, float, bool, datetime.date, datetime.time, datetime.datetime
]
TotalCompletionsType = Optional[Union[List[str], str]]


class SequenceGenerator:
@@ -461,6 +461,47 @@ def prepare_generation_parameters(

return generation_params

def strip_completions(
self,
completions,
prompts: Union[str, List[str]],
aligned_prompts: Union[str, List[str]],
):
"""Remove characters generated through token alignment from the completions.
Because token alignment makes the model re-generate some of the characters at
the end of the prompt, we remove those characters from the beginning of the
completions so that only the text generated after the end of the user prompts is returned.
Parameters
----------
completions
Text generated by the model
prompts
The original prompts provided by the user
aligned_prompts
The prompts of the user after token alignment (what's given to the model)
Returns
-------
The stripped completions
"""
if isinstance(prompts, str):
if isinstance(completions, str):
return completions[len(prompts) - len(aligned_prompts) :]

return [
self.strip_completions(completion, prompts, aligned_prompts)
for completion in completions
]

return [
self.strip_completions(completion, prompt, aligned_prompt)
for completion, prompt, aligned_prompt in zip(
completions, prompts, aligned_prompts
)
]
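
An illustrative sketch of the offset logic above (hypothetical values, not part of the diff): token alignment shortens the prompt, the model re-generates the trailing characters, and the length difference is sliced off the completion.

# Hypothetical values illustrating strip_completions.
prompt = "The answer is"        # what the user passed
aligned_prompt = "The answer"   # prompt after token alignment; " is" will be re-generated
completion = " is 42"           # raw model output, starts by re-generating " is"

offset = len(prompt) - len(aligned_prompt)  # 3
assert completion[offset:] == " 42"         # only the new text is returned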

def format_sequence(self, sequence: str) -> FormattedOutput:
"""Translate the generated sequence to another type.
@@ -485,6 +526,7 @@ def __call__(
max_tokens: Optional[int] = None,
stop_at: Optional[Union[str, List[str]]] = None,
seed: Optional[int] = None,
token_healing_enabled: bool = True,
**model_specific_params,
):
"""Generate text from a prompt of list of prompts."""
@@ -500,32 +542,106 @@ def format(sequences):
max_tokens, stop_at, seed
)

# If token healing is disabled or unavailable for the type of FSM used by the processor,
# the aligned_prompts are simply the original prompts.
aligned_prompts = self.logits_processor.setup_processor(
prompts, token_healing_enabled
)

completions = self.model.generate(
prompts,
aligned_prompts,
generation_params,
self.logits_processor,
self.sampling_params,
**model_specific_params,
)

return format(completions)
stripped_completions = self.strip_completions(
completions, prompts, aligned_prompts
)

return format(stripped_completions)
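
A hypothetical usage sketch of the new keyword argument (the factory call and prompt are illustrative, not taken from this commit):

# Illustrative only: any SequenceGeneratorAdapter built by outlines.generate.* should accept it.
generator = outlines.generate.regex(model, r"[0-9]+")

# Token alignment is applied by default...
answer = generator("The answer is ", max_tokens=8)

# ...and can be disabled per call.
answer = generator("The answer is ", max_tokens=8, token_healing_enabled=False)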

def stream(
self,
prompts: Union[str, List[str]],
max_tokens: Optional[int] = None,
stop_at: Optional[Union[str, List[str]]] = None,
seed: Optional[int] = None,
token_healing_enabled: bool = True,
**model_specific_params,
):
"""Return a text generator from a prompt or a list of prompts."""

def add_chunks_to_completions(
text_chunks: Union[str, List[str], List[List[str]], Sequence[str]],
total_completions: Optional[
Union[str, List[str], List[List[str]], Sequence[str]]
],
):
"""Append each of the text chunks at the end of the corresponding completions"""
if isinstance(text_chunks, str):
if isinstance(total_completions, str):
return total_completions + text_chunks
return text_chunks

if total_completions:
return [
add_chunks_to_completions(text_chunk, total_completion)
for text_chunk, total_completion in zip(
text_chunks, total_completions
)
]

return [
add_chunks_to_completions(text_chunk, None)
for text_chunk in text_chunks
]

def strip_text_chunks(
text_chunks: Union[str, List[str], List[List[str]], Sequence[str]],
stripped_completions: Union[str, List[str], List[List[str]], Sequence[str]],
):
"""Get the stripped text_chunks from the stripped_completions."""
if isinstance(text_chunks, str):
return (
stripped_completions[-len(text_chunks) :]
if len(text_chunks) > 0
else ""
)

return [
strip_text_chunks(text_chunk, stripped_completion)
for text_chunk, stripped_completion in zip(
text_chunks, stripped_completions
)
]

generation_params = self.prepare_generation_parameters(
max_tokens, stop_at, seed
)
return self.model.stream(

# If token healing is disabled or unavailable for the type of FSM used by the processor,
# the aligned_prompts are simply the original prompts.
aligned_prompts = self.logits_processor.setup_processor(
prompts, token_healing_enabled
)

total_completions: TotalCompletionsType = None

for text_chunks in self.model.stream(
prompts,
generation_params,
self.logits_processor,
self.sampling_params,
**model_specific_params,
)
):
total_completions = add_chunks_to_completions(
text_chunks, total_completions
)

stripped_completions = self.strip_completions(
total_completions, prompts, aligned_prompts
)

yield strip_text_chunks(text_chunks, stripped_completions)
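
An illustrative walk-through of the streaming path above for a single prompt (values are hypothetical):

# Suppose the user prompt is "Price:" and token alignment rewrote it to "Price",
# so the model first re-generates ":" before producing new text.
prompt, aligned_prompt = "Price:", "Price"
offset = len(prompt) - len(aligned_prompt)                    # 1

total = None
for chunk in [": 4", "2"]:                                    # raw chunks from model.stream
    total = (total or "") + chunk                             # add_chunks_to_completions
    stripped_total = total[offset:]                           # strip_completions
    yielded = stripped_total[-len(chunk):] if chunk else ""   # strip_text_chunks
    print(repr(yielded))                                      # ' 4', then '2'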
90 changes: 75 additions & 15 deletions outlines/processors/structured.py
@@ -24,12 +24,18 @@
limitations under the License.
"""
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Type, TypeGuard, Union

import torch
from pydantic import BaseModel

from outlines.fsm.guide import CFGGuide, Guide, RegexGuide, StopAtEOSGuide
from outlines.fsm.guide import (
CFGGuide,
Guide,
RegexGuide,
StopAtEOSGuide,
TokenHealerMixin,
)
from outlines.fsm.json_schema import build_regex_from_schema
from outlines.integrations.utils import convert_json_schema_to_str

@@ -61,8 +67,10 @@ def __init__(self, tokenizer: "Tokenizer", fsm: Guide):
The finite state machine which is used to bias the logits.
"""
self.tokenizer = tokenizer
self._fsm_states: Dict[int, int] = {hash(tuple([])): 0}
self._fsm_states: List[Dict[int, int]] = []
self.fsm: Guide = fsm
self._seq_fsms: List[Guide] = []
self._is_first_token = True
self._seq_start_idx: Optional[int] = None

def process_logits(
@@ -82,36 +90,87 @@ def process_logits(
torch.Tensor
The biased logits.
"""
if self._seq_start_idx is None:
samples = int(len(input_ids) / len(self._seq_fsms))
sequence_states: List[int] = [] # vector of states corresponding to `input_ids`

if self._is_first_token:
self._is_first_token = False
self._seq_start_idx = len(input_ids[0])

sequence_states: List[int] = [] # vector of states corresponding to `input_ids`
self._fsm_states = [
{hash(tuple([])): 0} for _ in range(len(self._seq_fsms))
]
sequence_states = [0] * len(input_ids)

for seq_ids in input_ids:
gen_ids = seq_ids[self._seq_start_idx :]
curr_state_key = hash(tuple(gen_ids))
else:
for i, seq_ids in enumerate(input_ids):
try:
prev_state_key = hash(tuple(seq_ids[self._seq_start_idx : -1]))
prev_state = self._fsm_states[i // samples][prev_state_key]

if curr_state_key not in self._fsm_states:
prev_state = self._fsm_states[hash(tuple(gen_ids[:-1]))]
curr_state = self.fsm.get_next_state(prev_state, gen_ids[-1])
self._fsm_states[curr_state_key] = curr_state
curr_state_key = hash(tuple(seq_ids[self._seq_start_idx :]))
curr_state = self._seq_fsms[i // samples].get_next_state(
prev_state, seq_ids[-1]
)

sequence_states.append(self._fsm_states[curr_state_key])
self._fsm_states[i // samples][curr_state_key] = curr_state
sequence_states.append(curr_state)

# This exception happens after sequence generation has finished when using beam search
except KeyError:
sequence_states.append(self._seq_fsms[i // samples].final_state)

mask = torch.full_like(logits, -math.inf)
for i, fsm_state in enumerate(sequence_states):
allowed_tokens = self.fsm.get_next_instruction(fsm_state).tokens
allowed_tokens = (
self._seq_fsms[i // samples].get_next_instruction(fsm_state).tokens
)
mask[i, allowed_tokens] = logits[i, allowed_tokens]

return mask
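
The i // samples indexing above assigns each row of input_ids to the FSM of the prompt it was generated from; a minimal sketch with hypothetical shapes:

# Hypothetical shapes: 2 prompts, 3 samples each (e.g. beam search or multinomial sampling).
num_prompts, batch_rows = 2, 6          # len(self._seq_fsms), len(input_ids)
samples = batch_rows // num_prompts     # 3 rows per prompt

assert [i // samples for i in range(batch_rows)] == [0, 0, 0, 1, 1, 1]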

def setup_processor(
self, prompts: Union[str, List[str]], token_healing_enabled: bool
) -> Union[str, List[str]]:
"""Prepare the processor to process logits for a specific set of prompts. Create a distinct
fsm for each prompt. If selected and available, apply prompt alignment to each fsm.
Parameters
----------
prompts
The text prompts provided by the user
Returns
-------
The prompts after prompt alignment has been applied, if selected and available;
the initial prompts unchanged otherwise.
"""
is_input_str = isinstance(prompts, str)
if isinstance(prompts, str):
prompts = [prompts]

self._seq_fsms = [self.fsm.copy() for _ in range(len(prompts))]

if isinstance(self.fsm, TokenHealerMixin) and token_healing_enabled:
aligned_prompts = [
fsm.align_prompt_tokens(prompt) # type: ignore
for fsm, prompt in zip(self._seq_fsms, prompts)
]
else:
aligned_prompts = prompts

if is_input_str:
return aligned_prompts[0]
return aligned_prompts
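
A minimal sketch of the setup_processor contract (tokenizer, guide, and prompts are hypothetical stand-ins):

# Hypothetical objects; assumes the guide supports TokenHealerMixin.
processor = FSMLogitsProcessor(tokenizer, fsm)

aligned = processor.setup_processor("Price:", token_healing_enabled=True)
# -> a single (possibly shortened) string; processor._seq_fsms holds one FSM copy

aligned = processor.setup_processor(["Price:", "Name:"], token_healing_enabled=False)
# -> the same list unchanged, with one independent FSM copy per prompt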

def copy(self) -> "FSMLogitsProcessor":
"""Return a copy of the logits processor."""
return FSMLogitsProcessor(tokenizer=self.tokenizer, fsm=self.fsm.copy())


class TextLogitsProcessor(FSMLogitsProcessor):
"""Bias generation for free text (required because of prompt alignment).
Attributes
----------
tokenizer
@@ -122,6 +181,7 @@ class TextLogitsProcessor(FSMLogitsProcessor):

def __init__(self, tokenizer: "Tokenizer"):
"""Compile the FSM that drives the regex-guided generation.
Parameters
----------
tokenizer
@@ -213,4 +273,4 @@ def __init__(self, cfg_str: str, tokenizer: "Tokenizer"):
The tokenizer used to convert tokens to ids.
"""
cfg_automata = CFGGuide(cfg_string=cfg_str, tokenizer=tokenizer)
super().__init__(tokenizer=tokenizer, fsm=cfg_automata)
super().__init__(tokenizer=tokenizer, fsm=cfg_automata)
17 changes: 16 additions & 1 deletion tests/generate/test_generator.py
@@ -21,6 +21,9 @@ def test_sequence_generator_class():
class MockFSM:
first_state = 0

def align_prompt_tokens(self, prompt):
return prompt

def get_next_state(self, state, next_token_ids):
return 4

@@ -39,7 +42,7 @@ def encode(self, _):
return torch.tensor([[0, 1, 2, 3]]), torch.tensor([[1, 1, 1, 1]])

def decode(self, tokens):
return ["testx"[i] for i in tokens]
return ["".join(["testx"[int(i)] for i in tokens[0]])]

class MockModel:
def __init__(self):
@@ -77,6 +80,9 @@ def __call__(self, biased_logits, *_):

def test_sequence_generator_1d_single_iteration():
class MockFSM:
def align_prompt_tokens(self, prompt):
return prompt

def get_next_state(self, state, next_token_ids):
return 0

@@ -132,6 +138,9 @@ def sampler(biased_logits, *_):

def test_sequence_generator_1d_several_iterations():
class MockFSM:
def align_prompt_tokens(self, prompt):
return prompt

def get_next_state(self, state, next_token_ids):
return state + 1

@@ -194,6 +203,9 @@ def sampler(biased_logits, *_):

def test_sequence_generator_2d_single_iteration():
class MockFSM:
def align_prompt_tokens(self, prompt):
return prompt

def get_next_state(self, state, next_token_ids):
return 0

@@ -260,6 +272,9 @@ def sampler(biased_logits, *_):

def test_sequence_generator_2d_several_iterations():
class MockFSM:
def align_prompt_tokens(self, prompt):
return prompt

def get_next_state(self, state, next_token_ids):
return state + 1

