Skip to content

Commit

Permalink
Grammar: Preliminary Formatron KBNF support
Browse files Browse the repository at this point in the history
  • Loading branch information
DocShotgun committed Nov 23, 2024
1 parent 0836a93 commit a9f39bc
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion backends/exllamav2/grammar.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import traceback
import typing
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
from exllamav2.generator.filters import ExLlamaV2Filter
from loguru import logger
Expand All @@ -7,6 +8,7 @@
from formatron.formatter import FormatterBuilder
from formatron.schemas import json_schema
from formatron.integrations.exllamav2 import create_formatter_filter
from formatron.extractor import NonterminalExtractor


def clear_grammar_func_cache():
Expand Down Expand Up @@ -98,7 +100,11 @@ def add_kbnf_filter(
try:
# Validate KBNF and create formatter
f = FormatterBuilder()
# TODO: Implement this
f.append_line(
f"{f.extractor(
lambda nonterminal: CustomExtractor(nonterminal, kbnf_string)
)}"
)
except Exception:
logger.error(
"Skipping because the KBNF string couldn't be parsed. "
Expand All @@ -111,3 +117,19 @@ def add_kbnf_filter(

# Append the filters
self.filters.append(lmfilter)


class CustomExtractor(NonterminalExtractor):
def __init__(self, nonterminal: str, kbnf_string: str):
super().__init__(nonterminal)
self.kbnf_string = kbnf_string

# Fails without an extract function defined
# No idea what it does or why it's needed, but this seems to work
# TODO: Figure out how to do this properly
def extract(self, input_str: str) -> typing.Optional[tuple[str, typing.Any]]:
return input_str[len(input_str) :], input_str[: len(input_str)]

@property
def kbnf_definition(self) -> str:
return self.kbnf_string.replace("start", self.nonterminal)

0 comments on commit a9f39bc

Please sign in to comment.