
Commit

Relevant changes from lapp0/fix-vllm-group-generation
viktor-ferenczi committed Feb 5, 2024
1 parent 36f15ba commit 244c914
Showing 1 changed file with 68 additions and 44 deletions.
outlines/serve/vllm.py: 68 additions & 44 deletions (112 changes)
@@ -1,12 +1,11 @@
 """Make vLLM compatible with Outlines' guided generation."""
 import json
 import math
-from collections import defaultdict
-from typing import DefaultDict, List
+from typing import Dict, List, Tuple

 import torch

-from outlines.fsm.fsm import RegexFSM
+from outlines.fsm.fsm import FSMState, RegexFSM
 from outlines.fsm.json_schema import build_regex_from_object
@@ -39,7 +38,59 @@ def _patched_apply_logits_processors(
     return logits


+def adapt_tokenizer(tokenizer):
+    """Adapt vLLM's tokenizer to use to compile the FSM.
+
+    The API of Outlines tokenizers is slightly different to that of
+    `transformers`. In addition, we need to handle the missing spaces to
+    Llama's tokenizer to be able to compile FSMs for this model.
+
+    """
+    tokenizer.vocabulary = tokenizer.get_vocab()
+    tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+    def convert_token_to_string(token: str) -> str:
+        from transformers.file_utils import SPIECE_UNDERLINE
+
+        string = tokenizer.convert_tokens_to_string([token])
+
+        # A hack to handle missing spaces to HF's Llama tokenizers
+        if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+            return " " + string
+
+        return string
+
+    tokenizer.convert_token_to_string = convert_token_to_string
+
+    return tokenizer
+
+
+class CachedRegexFSM(RegexFSM):
+    def __init__(self, regex_string: str, adapted_tokenizer):
+        super().__init__(regex_string, adapted_tokenizer)
+        self.state_cache: Dict[int, FSMState] = {}
+
+    def get_state_by_token_ids(self, input_ids: Tuple[int]) -> FSMState:
+        state_key = hash(input_ids)
+
+        if not input_ids:
+            self.state_cache[state_key] = FSMState(0)
+
+        elif state_key not in self.state_cache:
+            prev_state_key = hash(input_ids[:-1])
+            prev_state = self.state_cache[prev_state_key]
+
+            last_token = input_ids[-1]
+            new_state = self.next_state(prev_state, last_token)
+            self.state_cache[state_key] = new_state
+
+        return self.state_cache[state_key]
+
+
 class RegexLogitsProcessor:
+    fsm_cache: Dict[str, CachedRegexFSM] = {}
+    adapted_tokenizer = None
+
     def __init__(self, regex_string, llm):
         """Compile the FSM that drives the regex-guided generation.
@@ -51,58 +102,29 @@ def __init__(self, regex_string, llm):
             An instance of `vllm.LLM`

         """
-        tokenizer = self.adapt_tokenizer(llm.tokenizer.tokenizer)
+        cls = self.__class__

-        fsm = RegexFSM(regex_string, tokenizer)
-        self.fsm = fsm
+        if cls.adapted_tokenizer is None:
+            cls.adapted_tokenizer = adapt_tokenizer(llm.tokenizer.tokenizer)

-    def __call__(
-        self, seq_id: int, input_ids: List[int], scores: torch.Tensor
-    ) -> torch.Tensor:
-        """Use the FSM to bias the logits before sampling the next token."""
+        fsm = self.fsm_cache.get(regex_string)
+        if fsm is None:
+            fsm = CachedRegexFSM(regex_string, cls.adapted_tokenizer)
+            self.fsm_cache[regex_string] = fsm

-        if len(input_ids) == 0:  # Initialize the fsm states
-            self.fsm_state: DefaultDict[int, int] = defaultdict(int)
-        else:
-            last_token = input_ids[-1]
-            self.fsm_state[seq_id] = self.fsm.next_state(
-                self.fsm_state[seq_id], last_token
-            )
+        self.fsm = fsm

-        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
+    def __call__(self, seq_id: int, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
+        """Use the FSM to bias the logits before sampling the next token."""
+        state = self.fsm.get_state_by_token_ids(tuple(input_ids))
+        allowed_tokens = self.fsm.allowed_token_ids(state)

         mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device)
         mask[allowed_tokens] = 0
         biased_scores = scores + mask

         return biased_scores

-    def adapt_tokenizer(self, tokenizer):
-        """Adapt vLLM's tokenizer to use to compile the FSM.
-
-        The API of Outlines tokenizers is slightly different to that of
-        `transformers`. In addition we need to handle the missing spaces to
-        Llama's tokenizer to be able to compile FSMs for this model.
-
-        """
-        tokenizer.vocabulary = tokenizer.get_vocab()
-        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
-
-        def convert_token_to_string(token: str) -> str:
-            from transformers.file_utils import SPIECE_UNDERLINE
-
-            string = tokenizer.convert_tokens_to_string([token])
-
-            # A hack to handle missing spaces to HF's Llama tokenizers
-            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
-                return " " + string
-
-            return string
-
-        tokenizer.convert_token_to_string = convert_token_to_string
-
-        return tokenizer
-

 class JSONLogitsProcessor(RegexLogitsProcessor):
     def __init__(self, schema, llm):
@@ -118,5 +140,7 @@ def __init__(self, schema, llm):
         """
         if isinstance(schema, dict):
            schema = json.dumps(schema)
+
         regex_string = build_regex_from_object(schema)
+
         super().__init__(regex_string, llm)
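
The key behavioral change is that per-sequence FSM state no longer lives on the processor instance (the old `fsm_state` defaultdict keyed by `seq_id`); `CachedRegexFSM` instead keys each state by a hash of the sequence's token-id prefix, and compiled FSMs are shared through the class-level `fsm_cache`. A rough standalone sketch of the prefix-keyed caching scheme (the `PrefixStateCache` class and `toy_next_state` transition below are illustrative stand-ins, not part of this commit or of Outlines):

from typing import Dict, Tuple

def toy_next_state(state: int, token_id: int) -> int:
    # Stand-in for RegexFSM.next_state: a deterministic toy transition,
    # not the real Outlines FSM.
    return (state * 31 + token_id + 1) % 97

class PrefixStateCache:
    """Cache one FSM state per token-id prefix, keyed by the prefix's hash."""

    def __init__(self) -> None:
        self.state_cache: Dict[int, int] = {}

    def get_state(self, input_ids: Tuple[int, ...]) -> int:
        state_key = hash(input_ids)
        if not input_ids:
            # An empty prefix always maps to the FSM's initial state.
            self.state_cache[state_key] = 0
        elif state_key not in self.state_cache:
            # The parent prefix was cached on the previous decoding step,
            # so each new token costs a single FSM transition.
            prev_state = self.state_cache[hash(input_ids[:-1])]
            self.state_cache[state_key] = toy_next_state(prev_state, input_ids[-1])
        return self.state_cache[state_key]

cache = PrefixStateCache()
for step in range(4):
    prefix = tuple(range(step))  # simulated sequence growing by one token per step
    print(prefix, "->", cache.get_state(prefix))

Because the processor is called with the full list of generated token ids at every decoding step, each parent prefix is cached before its child is requested, so a step costs one dictionary lookup plus at most one transition; this is why `get_state_by_token_ids` can assume `hash(input_ids[:-1])` is already present in the cache.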
