Merge pull request #201 from theroyallab/lmfe_fix
Fix LMFE
kingbri1 authored Sep 15, 2024
2 parents a2b4e3f + 318c425 commit 2d22183
Showing 4 changed files with 62 additions and 25 deletions.
53 changes: 45 additions & 8 deletions backends/exllamav2/grammar.py
@@ -1,9 +1,13 @@
 import traceback
 from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
 from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
-from lmformatenforcer import JsonSchemaParser, RegexParser
+from lmformatenforcer import (
+    JsonSchemaParser,
+    RegexParser,
+    TokenEnforcer,
+    CharacterLevelParser,
+)
 from lmformatenforcer.integrations.exllamav2 import (
-    ExLlamaV2TokenEnforcerFilter,
     build_token_enforcer_tokenizer_data,
 )
 from loguru import logger
@@ -54,12 +58,48 @@ def feed(self, token):
     def next(self):
         return self.fsm.allowed_token_ids(self.state), set()
 
+    def use_background_worker(self):
+        return True
+
 
 @lru_cache(10)
 def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
     return build_token_enforcer_tokenizer_data(tokenizer)
 
 
+class ExLlamaV2TokenEnforcerFilter(ExLlamaV2Filter):
+    """Filter class for LMFE"""
+
+    token_sequence: List[int]
+
+    def __init__(
+        self,
+        model: ExLlamaV2,
+        tokenizer: ExLlamaV2Tokenizer,
+        character_level_parser: CharacterLevelParser,
+    ):
+        super().__init__(model, tokenizer)
+        tokenizer_data = _get_lmfe_tokenizer_data(tokenizer)
+        self.token_enforcer = TokenEnforcer(tokenizer_data, character_level_parser)
+        self.token_sequence = []
+
+    def begin(self, prefix_str: str):
+        self.token_sequence = []
+
+    def feed(self, token):
+        self.token_sequence.append(int(token[0][0]))
+
+    def next(self):
+        allowed_tokens = self.token_enforcer.get_allowed_tokens(self.token_sequence)
+        if not hasattr(self, "allow_return_type_list"):
+            return set(allowed_tokens), set()
+        else:
+            return sorted(allowed_tokens), []
+
+    def use_background_worker(self):
+        return True
+
 
 def clear_grammar_func_cache():
     """Flush tokenizer_data cache to avoid holding references to
     tokenizers after unloading a model"""
@@ -98,9 +138,7 @@ def add_json_schema_filter(
         # Allow JSON objects or JSON arrays at the top level
         json_prefixes = ["[", "{"]
 
-        lmfilter = ExLlamaV2TokenEnforcerFilter(
-            schema_parser, _get_lmfe_tokenizer_data(tokenizer)
-        )
+        lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, schema_parser)
         prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, json_prefixes)
 
         # Append the filters
@@ -109,6 +147,7 @@ def add_regex_filter(
     def add_regex_filter(
         self,
         pattern: str,
+        model: ExLlamaV2,
         tokenizer: ExLlamaV2Tokenizer,
     ):
         """Adds an ExllamaV2 filter based on regular expressions."""
@@ -125,9 +164,7 @@ def add_regex_filter(
 
             return
 
-        lmfilter = ExLlamaV2TokenEnforcerFilter(
-            pattern_parser, _get_lmfe_tokenizer_data(tokenizer)
-        )
+        lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, pattern_parser)
 
         # Append the filters
         self.filters.append(lmfilter)
2 changes: 1 addition & 1 deletion backends/exllamav2/model.py
@@ -1141,7 +1141,7 @@ async def generate_gen(
         # Add regex filter if it exists
         regex_pattern = unwrap(kwargs.get("regex_pattern"))
         if regex_pattern:
-            grammar_handler.add_regex_filter(regex_pattern, self.tokenizer)
+            grammar_handler.add_regex_filter(regex_pattern, self.model, self.tokenizer)
 
         # Add EBNF filter if it exists
         grammar_string = unwrap(kwargs.get("grammar_string"))
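A hypothetical call-site sketch of how a request-level pattern reaches this branch; generate_gen's other parameters are elided and the kwarg value is illustrative:

# Hypothetical: a regex constraint supplied per request flows through
# kwargs into the grammar handler, which now also receives the model.
async for chunk in model_container.generate_gen(
    prompt,
    regex_pattern=r"[A-Z]{3}-\d{4}",  # constrain output to e.g. "ABC-1234"
):
    ...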
2 changes: 1 addition & 1 deletion backends/exllamav2/utils.py
@@ -8,7 +8,7 @@
 def check_exllama_version():
     """Verifies the exllama version"""
 
-    required_version = version.parse("0.2.1")
+    required_version = version.parse("0.2.2")
     current_version = version.parse(package_version("exllamav2").split("+")[0])
 
     unsupported_message = (
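A minimal sketch of the version gate this hunk bumps, assuming error handling that the truncated hunk hides; only the two parsed versions are taken from the diff, and the import alias and exception are assumptions:

# Sketch (assumed handling; the real unsupported_message is truncated above).
from importlib.metadata import version as package_version
from packaging import version

required_version = version.parse("0.2.2")
current_version = version.parse(package_version("exllamav2").split("+")[0])

if current_version < required_version:
    raise RuntimeError(
        f"exllamav2 {current_version} is below the required {required_version}"
    )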
30 changes: 15 additions & 15 deletions pyproject.toml
@@ -68,12 +68,12 @@ cu121 = [
     "torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
 
     # Exl2
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
 
     # Windows FA2 from https://github.com/bdashore3/flash-attention/releases
     "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
@@ -95,12 +95,12 @@ cu118 = [
     "torch @ https://download.pytorch.org/whl/cu118/torch-2.3.1%2Bcu118-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
 
     # Exl2
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
 
     # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases
     "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
@@ -119,9 +119,9 @@ amd = [
     "torch @ https://download.pytorch.org/whl/rocm6.0/torch-2.3.1%2Brocm6.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
 
     # Exl2
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
 ]
 
 # MARK: Ruff options
