Skip to content

Commit

Permalink
Model: Add EBNF grammar support
Browse files Browse the repository at this point in the history
Using the Outlines library, add support for supplying EBNF grammar strings
and passing them to the library for parsing.

From there, a wrapper is created and a filter is passed to generation.

Replace this with a more flexible in-house solution at some point.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
bdashore3 committed Feb 23, 2024
1 parent 8a22b53 commit cd21850
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 15 deletions.
120 changes: 105 additions & 15 deletions backends/exllamav2/grammar.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,63 @@
import traceback
from common.logger import init_logger
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2Sampler
from exllamav2.generator.filters import ExLlamaV2Filter

# Temporary, remove once the exllama version is bumped
# TODO: Remove after new exllama version is released
try:
from exllamav2.generator.filters import ExLlamaV2PrefixFilter

_exllama_filter_available = True
except ImportError:
_exllama_filter_available = False

try:
from lmformatenforcer import JsonSchemaParser
from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter

_lmformatenforcer_available = True
except ImportError:
_lmformatenforcer_available = False
logger = init_logger(__name__)


logger = init_logger(__name__)
class OutlinesTokenizerWrapper:
    """Wrapper for Outlines tokenizer.

    Wraps a tokenizer object and exposes a vocabulary mapping, EOS
    information, a special-token list and simple decode helpers. An instance
    is what ExLlamaV2EbnfFilter hands to Outlines' CFGFSM.
    """

    def __init__(self, tokenizer):
        """Build the lookup attributes from ``tokenizer``.

        Args:
            tokenizer: Object providing ``get_id_to_piece_list()``,
                ``eos_token_id`` and ``extended_id_to_piece``.
        """

        self.tokenizer = tokenizer
        id_to_piece = self.tokenizer.get_id_to_piece_list()

        # Piece text -> token id. If the same piece text occurs at several
        # ids, the highest id wins because later entries overwrite earlier.
        self.vocabulary = {piece: idx for idx, piece in enumerate(id_to_piece)}
        self.eos_token_id = self.tokenizer.eos_token_id
        self.eos_token = id_to_piece[self.tokenizer.eos_token_id]

        # NOTE(review): these are the *keys* of extended_id_to_piece, which by
        # its name are presumably token ids rather than piece strings --
        # confirm this is what the consumer expects.
        self.special_tokens = list(self.tokenizer.extended_id_to_piece.keys())

    def convert_token_to_string(self, token):
        """Return ``token`` unchanged."""

        return token

    def decode(self, tokens):
        """Concatenate the piece text of every id in ``tokens``.

        Args:
            tokens: Iterable of token ids usable as indices into the
                tokenizer's id-to-piece list.

        Returns:
            The joined string; empty when ``tokens`` is empty.
        """

        id_to_piece = self.tokenizer.get_id_to_piece_list()

        # str.join avoids rebuilding the string on every iteration.
        return "".join(id_to_piece[t] for t in tokens)


class ExLlamaV2EbnfFilter(ExLlamaV2Filter):
    """Filter class for context-free grammar via outlines.

    Holds an Outlines ``CFGFSM`` built from a grammar string plus the current
    FSM state, and advances that state as tokens are fed in.
    """

    def __init__(self, model, tokenizer, grammar):
        """Create the FSM for ``grammar`` and start at its first state.

        Args:
            model: Forwarded to the ``ExLlamaV2Filter`` constructor.
            tokenizer: Forwarded to the base constructor and wrapped in
                ``OutlinesTokenizerWrapper`` for the FSM.
            grammar: Grammar passed as-is to ``CFGFSM``.

        Raises:
            ImportError: If ``outlines`` cannot be imported.
        """

        # Imported here rather than at module level, so a missing Outlines
        # install only fails when this filter is actually constructed.
        from outlines.fsm.fsm import CFGFSM

        super().__init__(model, tokenizer)

        wrapper = OutlinesTokenizerWrapper(tokenizer)
        self.wrapped_tokenizer = wrapper
        self.fsm = CFGFSM(grammar, wrapper)
        self.state = self.fsm.first_state

    def begin(self, prefix_str=""):
        """Reset to the FSM's first state; ``prefix_str`` is not used."""

        initial_state = self.fsm.first_state
        self.state = initial_state

    def feed(self, token):
        """Advance the FSM state with ``token`` (read via ``token.item()``)."""

        following_state = self.fsm.next_state(self.state, token.item())
        self.state = following_state

    def next(self):
        """Return the FSM's allowed token ids for the current state and an empty set."""

        allowed_ids = self.fsm.allowed_token_ids(self.state)
        return allowed_ids, set()


class ExLlamaV2Grammar:
Expand All @@ -34,28 +72,80 @@ def add_json_schema_filter(
):
"""Adds an ExllamaV2 filter based on a JSON schema."""

# Check if the required dependencies can be imported
if not _exllama_filter_available:
logger.warning(
"ExllamaV2PrefixFilter is not available "
"in the currently installed ExllamaV2 version."
"in the currently installed ExllamaV2 version. "
"Skipping JSON schema parsing."
)

return

if not _lmformatenforcer_available:
# Import optional dependencies
try:
from lmformatenforcer import JsonSchemaParser
from lmformatenforcer.integrations.exllamav2 import (
ExLlamaV2TokenEnforcerFilter,
)
except ImportError:
logger.error(
"lmformatenforcer must be installed to parse a json schema.\n"
"Please run the following command: pip install lm-format-enforcer"
"Skipping JSON schema parsing because "
"lm-format-enforcer is not installed.\n"
"Please run the following command: "
"pip install lm-format-enforcer"
)

return

# Create the parser
schema_parser = JsonSchemaParser(json_schema)
try:
schema_parser = JsonSchemaParser(json_schema)
except Exception:
traceback.print_exc()
logger.error(
"Skipping because the JSON schema couldn't be parsed. "
"Please read the above error for more information."
)

return

lmfilter = ExLlamaV2TokenEnforcerFilter(schema_parser, tokenizer)
prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, "{")

# Append the filters
gen_settings.filters += [lmfilter, prefix_filter]
gen_settings.filters.extend([lmfilter, prefix_filter])
gen_settings.filter_prefer_eos = True

def add_ebnf_filter(
    self,
    ebnf_string: str,
    gen_settings: ExLlamaV2Sampler.Settings,
    model: ExLlamaV2,
    tokenizer: ExLlamaV2Tokenizer,
):
    """
    Add an EBNF grammar filter.
    Possibly replace outlines with an in-house solution in the future.

    On success the filter is appended to ``gen_settings.filters`` and
    ``gen_settings.filter_prefer_eos`` is set to True. If the installed
    ExllamaV2 is reported as too old, Outlines is missing, or the filter
    cannot be built from ``ebnf_string``, the problem is logged and
    ``gen_settings`` is left untouched.
    """

    # The flag is set at import time from whether ExLlamaV2PrefixFilter could
    # be imported; presumably a proxy for a new-enough ExllamaV2 -- confirm.
    if not _exllama_filter_available:
        logger.warning(
            "filter_prefer_eos is not available "
            "in the currently installed ExllamaV2 version. "
            "Skipping EBNF parsing."
        )

        return

    try:
        ebnf_filter = ExLlamaV2EbnfFilter(model, tokenizer, ebnf_string)
    except ImportError:
        # The filter imports Outlines lazily inside its constructor.
        logger.error(
            "Skipping EBNF parsing because Outlines is not installed.\n"
            "Please run the following command: pip install outlines"
        )

        return
    except Exception:
        # Mirror the JSON schema path: a grammar that cannot be turned into
        # a filter is reported and skipped instead of aborting the caller.
        traceback.print_exc()
        logger.error(
            "Skipping because the EBNF grammar couldn't be parsed. "
            "Please read the above error for more information."
        )

        return

    gen_settings.filters.append(ebnf_filter)
    gen_settings.filter_prefer_eos = True
8 changes: 8 additions & 0 deletions backends/exllamav2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,7 @@ def generate_gen(self, prompt: str, **kwargs):

# Initialize grammar handler
grammar_handler = ExLlamaV2Grammar()
gen_settings.filters = []

# Add JSON schema filter if it exists
json_schema = unwrap(kwargs.get("json_schema"))
Expand All @@ -769,6 +770,13 @@ def generate_gen(self, prompt: str, **kwargs):
json_schema, gen_settings, self.model, self.tokenizer
)

# Add EBNF filter if it exists
grammar_string = unwrap(kwargs.get("grammar_string"))
if grammar_string:
grammar_handler.add_ebnf_filter(
grammar_string, gen_settings, self.model, self.tokenizer
)

# Ban the EOS token if specified. If not, append to stop conditions
# as well.
# Set this below logging to avoid polluting the stop strings array
Expand Down
5 changes: 5 additions & 0 deletions common/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ class BaseSamplerRequest(BaseModel):
default_factory=lambda: get_default_sampler_value("json_schema"),
)

grammar_string: Optional[str] = Field(
default_factory=lambda: get_default_sampler_value("grammar_string"),
)

# Aliased variables
typical: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("typical", 1.0),
Expand Down Expand Up @@ -266,6 +270,7 @@ def to_gen_params(self, **kwargs):
"cfg_scale": self.cfg_scale,
"negative_prompt": self.negative_prompt,
"json_schema": self.json_schema,
"grammar_string": self.grammar_string,
}

return {**gen_params, **kwargs}
Expand Down

0 comments on commit cd21850

Please sign in to comment.