Skip to content

Commit

Permalink
[Frontend][Core] Update Outlines Integration from FSM to Guide (v…
Browse files Browse the repository at this point in the history
…llm-project#4109)

Co-authored-by: Simon Mo <[email protected]>
Co-authored-by: Breno Faria <[email protected]>
  • Loading branch information
3 people authored Jun 5, 2024
1 parent 3a6ae1d commit 7b0a0df
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 48 deletions.
2 changes: 1 addition & 1 deletion requirements-common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@ prometheus_client >= 0.18.0
prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer == 0.10.1
outlines == 0.0.34 # Requires torch >= 2.1.0
outlines >= 0.0.43 # Requires torch >= 2.1.0
typing_extensions
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
2 changes: 0 additions & 2 deletions tests/entrypoints/test_guided_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def test_guided_logits_processors():
tokenizer,
whitespace_pattern=None)

regex_LP.init_state()
token_ids = tokenizer.encode(
f"Give an example IPv4 address with this regex: {TEST_REGEX}")
tensor = torch.rand(32000)
Expand All @@ -72,7 +71,6 @@ def test_guided_logits_processors():
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)

json_LP.init_state()
token_ids = tokenizer.encode(
f"Give an employee profile that fits this schema: {TEST_SCHEMA}")
tensor = torch.rand(32000)
Expand Down
31 changes: 12 additions & 19 deletions vllm/model_executor/guided_decoding/outlines_decoding.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import asyncio
import concurrent.futures
from copy import copy
from enum import Enum
from functools import lru_cache
from json import dumps as json_dumps
from re import escape as regex_escape
from typing import Tuple, Union
Expand Down Expand Up @@ -54,8 +52,10 @@ class GuidedDecodingMode(Enum):


async def get_outlines_guided_decoding_logits_processor(
request: Union[CompletionRequest, ChatCompletionRequest],
tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
request: Union[CompletionRequest,
ChatCompletionRequest], tokenizer: PreTrainedTokenizerBase
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
Expand All @@ -64,23 +64,17 @@ async def get_outlines_guided_decoding_logits_processor(
"""
global global_thread_pool
guide, mode = _get_guide_and_mode(request)
if not guide:
if not guide or not mode:
return None

if global_thread_pool is None:
global_thread_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=2)
loop = asyncio.get_running_loop()

result = await loop.run_in_executor(global_thread_pool,
_get_cached_logits_processor, guide,
tokenizer, mode,
request.guided_whitespace_pattern)

logits_processor = copy(result)
# reset logits processor's internal state
logits_processor.init_state()
return logits_processor
return await loop.run_in_executor(global_thread_pool,
_get_logits_processor, guide, tokenizer,
mode, request.guided_whitespace_pattern)


def _get_guide_and_mode(
Expand Down Expand Up @@ -115,11 +109,10 @@ def _get_guide_and_mode(
return None, None


@lru_cache(maxsize=32)
def _get_cached_logits_processor(guide: str,
tokenizer: PreTrainedTokenizerBase,
mode: GuidedDecodingMode,
whitespace_pattern: Union[str, None]):
def _get_logits_processor(
guide: str, tokenizer: PreTrainedTokenizerBase, mode: GuidedDecodingMode,
whitespace_pattern: Union[str, None]
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
if mode == GuidedDecodingMode.JSON:
return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern)
elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
Expand Down
62 changes: 36 additions & 26 deletions vllm/model_executor/guided_decoding/outlines_logits_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,36 +21,40 @@
from typing import Callable, DefaultDict, Dict, List, Union

import torch
from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write
from outlines.fsm.json_schema import build_regex_from_schema
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase


class BaseLogitsProcessor:

def __init__(self):
# Child class should use initialize in their init.
self.fsm: FSM

def init_state(self):
"""Initialize the FSM states."""
self.fsm_state: DefaultDict[int, int] = defaultdict(int)
def __init__(self, guide: Guide):
self._guide: Guide = guide
self._fsm_state: DefaultDict[int, int] = defaultdict(int)

def __call__(self, input_ids: List[int],
scores: torch.Tensor) -> torch.Tensor:
"""Use the FSM to bias the logits before sampling the next token."""
seq_id = hash(tuple(input_ids))

if len(input_ids) == 0:
self.init_state()
else:
if len(input_ids) > 0:
last_token = input_ids[-1]
last_seq_id = hash(tuple(input_ids[:-1]))
self.fsm_state[seq_id] = self.fsm.next_state(
self.fsm_state[last_seq_id], last_token)
self._fsm_state[seq_id] = self._guide.get_next_state(
state=self._fsm_state[last_seq_id], token_id=last_token)

instruction = self._guide.get_next_instruction(
state=self._fsm_state[seq_id])

allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
if type(instruction) == Generate:
allowed_tokens = instruction.tokens
elif type(instruction) == Write:
# TODO: support fast forward tokens
allowed_tokens = [instruction.tokens[0]]
else:
raise TypeError(
f"Unsupported instruction type {type(instruction)}")

mask = torch.full((scores.shape[-1], ),
-math.inf,
Expand All @@ -62,6 +66,13 @@ def __call__(self, input_ids: List[int],

class RegexLogitsProcessor(BaseLogitsProcessor):

@classmethod
@lru_cache(maxsize=32)
def _get_guide(cls, regex_string: str,
tokenizer: PreTrainedTokenizerBase) -> Guide:
tokenizer = _adapt_tokenizer(tokenizer)
return RegexGuide(regex_string, tokenizer)

def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
"""Compile the FSM that drives the regex-structured generation.
Expand All @@ -73,9 +84,8 @@ def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
The model's tokenizer
"""
tokenizer = _adapt_tokenizer(tokenizer)
fsm = RegexFSM(regex_string, tokenizer)
self.fsm = fsm
super().__init__(
RegexLogitsProcessor._get_guide(regex_string, tokenizer))


class JSONLogitsProcessor(RegexLogitsProcessor):
Expand Down Expand Up @@ -115,6 +125,12 @@ def __init__(self, schema: Union[str, Dict, BaseModel],

class CFGLogitsProcessor(BaseLogitsProcessor):

@classmethod
@lru_cache(maxsize=32)
def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide:
tokenizer = _adapt_tokenizer(tokenizer)
return CFGGuide(cfg, tokenizer)

def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
"""Compile the FSM that drives the context free grammar generation.
Expand All @@ -126,17 +142,11 @@ def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
The model's tokenizer
"""
tokenizer = _adapt_tokenizer(tokenizer)
fsm = CFGFSM(cfg, tokenizer)
self.fsm = fsm

def init_state(self):
"""Initialize state with a CFGFSM copy."""
super().init_state()
self.fsm = self.fsm.copy()
super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer))
self._guide = self._guide.copy()


@lru_cache
@lru_cache(maxsize=32)
def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
"""Adapt vLLM's tokenizer to use to compile the FSM.
Expand Down

0 comments on commit 7b0a0df

Please sign in to comment.