Switch grammar backend to Formatron #252

Merged
merged 8 commits on Dec 13, 2024

217 changes: 90 additions & 127 deletions backends/exllamav2/grammar.py
@@ -1,110 +1,16 @@
import traceback
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
from lmformatenforcer import (
JsonSchemaParser,
RegexParser,
TokenEnforcer,
CharacterLevelParser,
)
from lmformatenforcer.integrations.exllamav2 import (
build_token_enforcer_tokenizer_data,
)
from loguru import logger
from typing import List
import typing
from functools import lru_cache
from typing import List


class OutlinesTokenizerWrapper:
"""Wrapper for Outlines tokenizer"""

def __init__(self, tokenizer):
self.tokenizer = tokenizer
id_to_piece = self.tokenizer.get_id_to_piece_list()
self.vocabulary = {piece: idx for idx, piece in enumerate(id_to_piece)}
self.eos_token_id = self.tokenizer.eos_token_id
self.eos_token = id_to_piece[self.tokenizer.eos_token_id]
self.special_tokens = list(self.tokenizer.extended_id_to_piece.keys())

def convert_token_to_string(self, token):
return token

def decode(self, tokens):
s = ""
id_to_piece = self.tokenizer.get_id_to_piece_list()
for t in tokens:
s += id_to_piece[t]
return s


class ExLlamaV2EbnfFilter(ExLlamaV2Filter):
"""Filter class for context-free grammar via outlines"""

def __init__(self, model, tokenizer, grammar):
from outlines.fsm.fsm import CFGFSM

super().__init__(model, tokenizer)

self.wrapped_tokenizer = OutlinesTokenizerWrapper(tokenizer)
self.fsm = CFGFSM(grammar, self.wrapped_tokenizer)
self.state = self.fsm.first_state

def begin(self, prefix_str=""):
self.state = self.fsm.first_state

def feed(self, token):
self.state = self.fsm.next_state(self.state, token.item())

def next(self):
return self.fsm.allowed_token_ids(self.state), set()

def use_background_worker(self):
return True


@lru_cache(10)
def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
return build_token_enforcer_tokenizer_data(tokenizer)


class ExLlamaV2TokenEnforcerFilter(ExLlamaV2Filter):
"""Filter class for LMFE"""

token_sequence: List[int]

def __init__(
self,
model: ExLlamaV2,
tokenizer: ExLlamaV2Tokenizer,
character_level_parser: CharacterLevelParser,
):
super().__init__(model, tokenizer)
tokenizer_data = _get_lmfe_tokenizer_data(tokenizer)
self.token_enforcer = TokenEnforcer(tokenizer_data, character_level_parser)
self.token_sequence = []

def begin(self, prefix_str: str):
self.token_sequence = []

def feed(self, token):
self.token_sequence.append(int(token[0][0]))

def next(self):
allowed_tokens = self.token_enforcer.get_allowed_tokens(self.token_sequence)
if not hasattr(self, "allow_return_type_list"):
return set(allowed_tokens), set()
else:
return sorted(allowed_tokens), []

def use_background_worker(self):
return True


def clear_grammar_func_cache():
"""Flush tokenizer_data cache to avoid holding references to
tokenizers after unloading a model"""

_get_lmfe_tokenizer_data.cache_clear()
import torch
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
from exllamav2.generator.filters import ExLlamaV2Filter
from formatron.extractor import NonterminalExtractor
from formatron.formatter import FormatterBuilder
from formatron.integrations.exllamav2 import FormatterFilter, create_engine_vocabulary
from formatron.schemas import json_schema
from loguru import logger


class ExLlamaV2Grammar:
@@ -117,15 +23,24 @@ def __init__(self):

def add_json_schema_filter(
self,
json_schema: dict,
schema: dict,
model: ExLlamaV2,
tokenizer: ExLlamaV2Tokenizer,
):
"""Adds an ExllamaV2 filter based on a JSON schema."""

# Create the parser
try:
schema_parser = JsonSchemaParser(json_schema)
# Add fields required by formatron if not present
if "$id" not in schema:
schema["$id"] = "https://example.com/example.json"
if "$schema" not in schema:
schema["$schema"] = "http://json-schema.org/draft-07/schema#"

# Validate schema and create formatter
schema = json_schema.create_schema(schema)
f = FormatterBuilder()
f.append_line(f"{f.json(schema)}")
except Exception:
traceback.print_exc()
logger.error(
@@ -135,14 +50,10 @@ def add_json_schema_filter(

return

# Allow JSON objects or JSON arrays at the top level
json_prefixes = ["[", "{"]

lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, schema_parser)
prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, json_prefixes)
lmfilter = _create_formatter_filter(model, tokenizer, f)

# Append the filters
self.filters.extend([lmfilter, prefix_filter])
self.filters.append(lmfilter)
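# Illustration (not in this diff): a bare request schema such as
#   {"type": "object", "properties": {"name": {"type": "string"}}}
# gets "$id" and "$schema" filled in by the guards above, is validated
# by json_schema.create_schema(), and f.json(schema) then constrains
# decoding so the output stays valid against the schema.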

def add_regex_filter(
self,
@@ -154,7 +65,9 @@ def add_regex_filter(

# Create the parser
try:
pattern_parser = RegexParser(pattern)
# Validate regex and create formatter
f = FormatterBuilder()
f.append_line(f"{f.regex(pattern)}")
except Exception:
traceback.print_exc()
logger.error(
@@ -164,32 +77,82 @@

return

lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, pattern_parser)
lmfilter = _create_formatter_filter(model, tokenizer, f)

# Append the filters
self.filters.append(lmfilter)
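# Illustration (not in this diff): a hypothetical pattern such as
#   r"\d{4}-\d{2}-\d{2}"
# would constrain decoding to an ISO-style date; f.regex(pattern) builds
# the Formatron formatter the same way f.json does above.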

def add_ebnf_filter(
def add_kbnf_filter(
self,
ebnf_string: str,
kbnf_string: str,
model: ExLlamaV2,
tokenizer: ExLlamaV2Tokenizer,
):
"""
Add an EBNF grammar filter.
Possibly replace outlines with an in-house solution in the future.
"""
"""Adds an ExllamaV2 filter based on KBNF grammar."""

# Create the parser
try:
ebnf_filter = ExLlamaV2EbnfFilter(model, tokenizer, ebnf_string)
except ImportError:
# Validate KBNF and create formatter
f = FormatterBuilder()
f.append_line(
f"""{f.extractor(lambda nonterminal:
CFGExtractor(nonterminal, kbnf_string))}"""
)
except Exception:
logger.error(
"Skipping EBNF parsing because Outlines is not installed.\n"
"Please run the following command in your environment "
"to install extra packages:\n"
"pip install -U .[extras]"
"Skipping because the KBNF string couldn't be parsed. "
"Please read the above error for more information."
)

return

self.filters.append(ebnf_filter)
lmfilter = _create_formatter_filter(model, tokenizer, f)

# Append the filters
self.filters.append(lmfilter)


class CFGExtractor(NonterminalExtractor):
"""Extractor class for KBNF context-free grammar"""

def __init__(self, nonterminal: str, kbnf_string: str):
super().__init__(nonterminal)
self.kbnf_string = kbnf_string

# Return the entire input string as the extracted string
def extract(self, input_str: str) -> typing.Optional[tuple[str, typing.Any]]:
return "", input_str

@property
def kbnf_definition(self) -> str:
return self.kbnf_string.replace("start", self.nonterminal)
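# Illustration (not in this diff): given a hypothetical grammar string
#   start ::= "yes" | "no" ;
# Formatron assigns this extractor a unique nonterminal name, and
# kbnf_definition substitutes that name for "start" in the grammar text
# so the rule can be embedded in the larger KBNF document Formatron builds.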


@lru_cache(1)
def _create_cached_engine_vocabulary(tokenizer: ExLlamaV2Tokenizer):
"""Build and cache engine vocabulary on first grammar run"""

return create_engine_vocabulary(tokenizer)


def _create_formatter_filter(
model: ExLlamaV2, tokenizer: ExLlamaV2Tokenizer, formatter_builder: FormatterBuilder
) -> ExLlamaV2Filter:
"""
Create a formatter filter for the ExLlamaV2 engine.
Minimalist clone of formatron.integrations.exllamav2.create_formatter_filter
with lru_cache enabled for engine vocabulary
"""

vocab = _create_cached_engine_vocabulary(tokenizer)
f = formatter_builder.build(
vocab, lambda tokens: tokenizer.decode(torch.tensor(tokens))
)
return FormatterFilter(model, tokenizer, f)


def clear_grammar_func_cache():
"""Flush tokenizer_data cache to avoid holding references to
tokenizers after unloading a model"""

_create_cached_engine_vocabulary.cache_clear()
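
For orientation, a minimal usage sketch of the new handler (not part of the diff; model and tokenizer are assumed to be an already-loaded ExLlamaV2 model and its ExLlamaV2Tokenizer):

# Minimal sketch under the assumptions above.
handler = ExLlamaV2Grammar()

# Constrain decoding to JSON matching a schema; the handler injects
# "$id"/"$schema" before Formatron validates it.
handler.add_json_schema_filter(
    {"type": "object", "properties": {"answer": {"type": "string"}}},
    model,
    tokenizer,
)

# handler.filters now holds Formatron FormatterFilter objects to pass as
# the ExLlamaV2 generator's token filters.

Because the engine vocabulary is cached with lru_cache(1), only the first constrained generation after a model load pays the vocabulary-build cost; clear_grammar_func_cache() drops the cache when the model is unloaded.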
2 changes: 1 addition & 1 deletion backends/exllamav2/model.py
@@ -1194,7 +1194,7 @@ async def generate_gen(
# Add KBNF filter if it exists
grammar_string = unwrap(kwargs.get("grammar_string"))
if grammar_string:
grammar_handler.add_ebnf_filter(grammar_string, self.model, self.tokenizer)
grammar_handler.add_kbnf_filter(grammar_string, self.model, self.tokenizer)

# Set banned strings
banned_strings: List[str] = unwrap(kwargs.get("banned_strings"), [])
22 changes: 11 additions & 11 deletions pyproject.toml
@@ -26,7 +26,8 @@ dependencies = [
"sse-starlette",
"packaging",
"tokenizers",
"lm-format-enforcer >= 0.9.6",
"formatron",
"kbnf>=0.4.1",
"aiofiles",
"aiohttp",
"async_lru",
@@ -53,7 +54,6 @@
[project.optional-dependencies]
extras = [
# Heavy dependencies that aren't for everyday use
"outlines",
"infinity-emb",
"sentence-transformers",
]
@@ -70,12 +70,12 @@ cu121 = [
"torch @ https://download.pytorch.org/whl/cu121/torch-2.5.1%2Bcu121-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",

# Exl2
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.5/exllamav2-0.2.5+cu121.torch2.5.0-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.5/exllamav2-0.2.5+cu121.torch2.5.0-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.5/exllamav2-0.2.5+cu121.torch2.5.0-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.5/exllamav2-0.2.5+cu121.torch2.5.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.5/exllamav2-0.2.5+cu121.torch2.5.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.5/exllamav2-0.2.5+cu121.torch2.5.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.6/exllamav2-0.2.6+cu121.torch2.5.0-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.6/exllamav2-0.2.6+cu121.torch2.5.0-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.6/exllamav2-0.2.6+cu121.torch2.5.0-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.6/exllamav2-0.2.6+cu121.torch2.5.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.6/exllamav2-0.2.6+cu121.torch2.5.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.6/exllamav2-0.2.6+cu121.torch2.5.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",

# Windows FA2 from https://github.com/bdashore3/flash-attention/releases
"flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.7.0.post2/flash_attn-2.7.0.post2+cu124torch2.5.1cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
@@ -99,9 +99,9 @@ amd = [
"torch @ https://download.pytorch.org/whl/rocm6.0/torch-2.4.1%2Brocm6.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",

# Exl2
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.5/exllamav2-0.2.5+rocm6.1.torch2.4.0-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.5/exllamav2-0.2.5+rocm6.1.torch2.4.0-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.5/exllamav2-0.2.5+rocm6.1.torch2.4.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.6/exllamav2-0.2.6+rocm6.1.torch2.4.0-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.6/exllamav2-0.2.6+rocm6.1.torch2.4.0-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.6/exllamav2-0.2.6+rocm6.1.torch2.4.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
]

# MARK: Ruff options