
Use token_ids to track the FSM state for each sequence in the vLLM integration #539

Closed
wants to merge 16 commits
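For orientation, the sketch below distills the approach this PR takes (it is not part of the diff): instead of keying FSM state on vLLM's seq_id, the logits processor derives the state from the tuple of generated token ids, caching the state reached by each prefix so only one FSM transition is needed per decoding step. Class and method names in the sketch are illustrative, not the PR's actual identifiers.

# Minimal sketch (not the PR's code) of tracking FSM state by token-id prefix.
# `fsm` is assumed to expose `next_state(state, token_id)` and an initial
# state of 0, mirroring outlines' RegexFSM interface.
from typing import Dict, Tuple


class PrefixStateTracker:
    def __init__(self, fsm):
        self.fsm = fsm
        self.state_cache: Dict[int, int] = {}

    def state_for(self, token_ids: Tuple[int, ...]) -> int:
        key = hash(token_ids)
        if not token_ids:
            # An empty prefix always starts in the FSM's initial state.
            self.state_cache[key] = 0
        elif key not in self.state_cache:
            # The prefix without its last token was resolved on the previous
            # decoding step, so a single transition brings the cache up to date.
            prev_state = self.state_cache[hash(token_ids[:-1])]
            self.state_cache[key] = self.fsm.next_state(prev_state, token_ids[-1])
        return self.state_cache[key]

Because the cache is keyed on the generated prefix rather than on seq_id, a reordering of sequences (as happens under beam search) cannot apply one sequence's state to another sequence's tokens.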
110 changes: 65 additions & 45 deletions outlines/serve/vllm.py
@@ -1,12 +1,11 @@
"""Make vLLM compatible with Outlines' guided generation."""
import json
import math
from collections import defaultdict
from typing import DefaultDict, List
from typing import Dict, List, Tuple

import torch

from outlines.fsm.fsm import RegexFSM
from outlines.fsm.fsm import FSMState, RegexFSM
from outlines.fsm.json_schema import build_regex_from_object


@@ -29,7 +28,7 @@ def _patched_apply_logits_processors(
logits_row = logits[logits_row_idx]
token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
for logits_processor in logits_processors:
logits_row = logits_processor(seq_id, token_ids, logits_row)
logits_row = logits_processor(token_ids, logits_row)
logits[logits_row_idx] = logits_row
logits_row_idx += 1
else:
@@ -39,7 +38,58 @@
return logits


def adapt_tokenizer(tokenizer):
"""Adapt vLLM's tokenizer to use to compile the FSM.

The API of Outlines tokenizers is slightly different to that of
`transformers`. In addition, we need to handle the missing spaces to
Llama's tokenizer to be able to compile FSMs for this model.

"""
tokenizer.vocabulary = tokenizer.get_vocab()
tokenizer.special_tokens = set(tokenizer.all_special_tokens)

def convert_token_to_string(token: str) -> str:
from transformers.file_utils import SPIECE_UNDERLINE

string = tokenizer.convert_tokens_to_string([token])

# A hack to handle missing spaces in HF's Llama tokenizers
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
return " " + string

return string

tokenizer.convert_token_to_string = convert_token_to_string

return tokenizer


class CachedRegexFSM(RegexFSM):
def __init__(self, regex_string: str, adapted_tokenizer):
super().__init__(regex_string, adapted_tokenizer)
self.state_cache: Dict[int, FSMState] = {}

def get_state_by_token_ids(self, input_ids: Tuple[int, ...]) -> FSMState:
state_key = hash(input_ids)

if not input_ids:
self.state_cache[state_key] = FSMState(0)

elif state_key not in self.state_cache:
prev_state_key = hash(input_ids[:-1])
prev_state = self.state_cache[prev_state_key]

last_token = input_ids[-1]
new_state = self.next_state(prev_state, last_token)
self.state_cache[state_key] = new_state

return self.state_cache[state_key]


class RegexLogitsProcessor:
fsm_cache: Dict[str, CachedRegexFSM] = {}

def __init__(self, regex_string, llm):
"""Compile the FSM that drives the regex-guided generation.

@@ -51,58 +101,26 @@ def __init__(self, regex_string, llm):
An instance of `vllm.LLM`

"""
tokenizer = self.adapt_tokenizer(llm.tokenizer.tokenizer)
adapted_tokenizer = adapt_tokenizer(llm.tokenizer.tokenizer)

fsm = self.fsm_cache.get(regex_string)
if fsm is None:
fsm = CachedRegexFSM(regex_string, adapted_tokenizer)
self.fsm_cache[regex_string] = fsm

fsm = RegexFSM(regex_string, tokenizer)
self.fsm = fsm

def __call__(
self, seq_id: int, input_ids: List[int], scores: torch.Tensor
) -> torch.Tensor:
def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
"""Use the FSM to bias the logits before sampling the next token."""

if len(input_ids) == 0: # Initialize the fsm states
self.fsm_state: DefaultDict[int, int] = defaultdict(int)
else:
last_token = input_ids[-1]
self.fsm_state[seq_id] = self.fsm.next_state(
self.fsm_state[seq_id], last_token
)

allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
state = self.fsm.get_state_by_token_ids(tuple(input_ids))
allowed_tokens = self.fsm.allowed_token_ids(state)

mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device)
mask[allowed_tokens] = 0
biased_scores = scores + mask

return biased_scores

def adapt_tokenizer(self, tokenizer):
"""Adapt vLLM's tokenizer to use to compile the FSM.

The API of Outlines tokenizers is slightly different to that of
`transformers`. In addition we need to handle the missing spaces to
Llama's tokenizer to be able to compile FSMs for this model.

"""
tokenizer.vocabulary = tokenizer.get_vocab()
tokenizer.special_tokens = set(tokenizer.all_special_tokens)

def convert_token_to_string(token: str) -> str:
from transformers.file_utils import SPIECE_UNDERLINE

string = tokenizer.convert_tokens_to_string([token])

# A hack to handle missing spaces in HF's Llama tokenizers
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
return " " + string

return string

tokenizer.convert_token_to_string = convert_token_to_string

return tokenizer


class JSONLogitsProcessor(RegexLogitsProcessor):
def __init__(self, schema, llm):
@@ -118,5 +136,7 @@ def __init__(self, schema, llm):
"""
if isinstance(schema, dict):
schema = json.dumps(schema)

regex_string = build_regex_from_object(schema)

super().__init__(regex_string, llm)
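For context (again, not part of the diff), here is a hedged sketch of how such a processor is typically attached to vLLM sampling. Whether `SamplingParams` accepts a `logits_processors` argument depends on the vLLM version, and the model name and prompt below are placeholders.

# Hedged usage sketch; assumes vllm.SamplingParams takes `logits_processors`,
# which may vary across vLLM versions. Model name and prompt are placeholders.
from vllm import LLM, SamplingParams
from outlines.serve.vllm import JSONLogitsProcessor

llm = LLM(model="some-model")  # placeholder model identifier
schema = {"type": "object", "properties": {"name": {"type": "string"}}}
params = SamplingParams(
    max_tokens=64,
    logits_processors=[JSONLogitsProcessor(schema, llm)],
)
outputs = llm.generate("Describe a person as JSON:", params)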
118 changes: 118 additions & 0 deletions tests/serve/test_vllm.py
@@ -0,0 +1,118 @@
import re

import torch

from outlines.serve.vllm import RegexLogitsProcessor, _patched_apply_logits_processors


class MockTokenizer:
vocabulary = {
**{chr(i): i for i in range(256)},
**{"eos": 256},
}
special_tokens = {"eos"}
eos_token_id = 256

@property
def inverse_vocabulary(self):
return {v: k for k, v in self.vocabulary.items()}

def decode(self, token_ids):
return "".join([self.inverse_vocabulary[t] for t in token_ids])

####
# vLLM tokenizer features
####
all_special_tokens = list(special_tokens)

def convert_tokens_to_string(self, token):
return token[0]

def get_vocab(self):
return MockTokenizer.vocabulary


class MockTokenizerGroup:
tokenizer = MockTokenizer()


class MockModel:
tokenizer = MockTokenizerGroup()


def sample_from_logits(logits):
probs = torch.exp(logits) / torch.sum(torch.exp(logits))
return torch.multinomial(probs, 1).item()


def test_time_regexp():
pattern = r"(0?[1-9]|1[0-2]):[0-5]\d\s?(am|pm)?"
llm = MockModel()
logits_processor = RegexLogitsProcessor(pattern, llm)

token_ids = []
while True:
random_scores = -10 + 20 * torch.rand(len(llm.tokenizer.vocabulary))
logits = logits_processor(
input_ids=token_ids,
scores=random_scores,
)
new_token_id = sample_from_logits(logits)
if new_token_id == llm.tokenizer.eos_token_id:
break
token_ids.append(new_token_id)

assert re.fullmatch(pattern, llm.tokenizer.decode(token_ids)) is not None


def test_time_regexp_multiple_samples():
Review comment (Member):

What is this testing?

Reply from @lapp0 (Contributor Author), Feb 5, 2024:

I observed a lack of stability in sequence order when using beam search with Outlines. This resulted in a new token for one sequence being applied to a different sequence.

This test reproduces that behavior. It fails on main and passes with these changes.

I will leave an explanatory doc string.

num_seq = 64

pattern = r"(0?[1-9]|1[0-2]):[0-5]\d\ ?(am|pm)?"
llm = MockModel()

class MockSeqData:
def __init__(self):
self.output_token_ids = []

class MockSamplingParams:
logits_processors = [RegexLogitsProcessor(pattern, llm)]

class MockSamplingMeta:
seq_groups = [[range(num_seq), MockSamplingParams()]] # seq_ids
seq_data = {seq_id: MockSeqData() for seq_id in range(num_seq)}

sampling_meta = MockSamplingMeta()

results = []
while True:
complete_seq_ids = set()

logits = torch.randn(len(sampling_meta.seq_data), len(llm.tokenizer.vocabulary))
new_logits = _patched_apply_logits_processors(logits, sampling_meta)
seq_ids = sorted(sampling_meta.seq_groups[0][0])
for logits_row, seq_id in zip(new_logits, seq_ids):
new_token_id = sample_from_logits(logits_row)
if new_token_id == llm.tokenizer.eos_token_id:
complete_seq_ids.add(seq_id)
results.append(sampling_meta.seq_data[seq_id].output_token_ids)
else:
sampling_meta.seq_data[seq_id].output_token_ids.append(new_token_id)

if complete_seq_ids:
seq_datas = [
sd
for seq_id, sd in sampling_meta.seq_data.items()
if seq_id not in complete_seq_ids
]
sampling_meta.seq_data = {
i: seq_data for i, seq_data in enumerate(seq_datas)
}
sampling_meta.seq_groups[0][0] = range(len(sampling_meta.seq_data))

if not sampling_meta.seq_data:
break

assert len(results) == num_seq
for result in results:
assert re.fullmatch(pattern, llm.tokenizer.decode(result)) is not None