diff --git a/backends/exllamav2/grammar.py b/backends/exllamav2/grammar.py
index 9c27491..9621e3f 100644
--- a/backends/exllamav2/grammar.py
+++ b/backends/exllamav2/grammar.py
@@ -1,4 +1,5 @@
 import traceback
+import typing
 from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
 from exllamav2.generator.filters import ExLlamaV2Filter
 from loguru import logger
@@ -7,6 +8,7 @@
 from formatron.formatter import FormatterBuilder
 from formatron.schemas import json_schema
 from formatron.integrations.exllamav2 import create_formatter_filter
+from formatron.extractor import NonterminalExtractor
 
 
 def clear_grammar_func_cache():
@@ -98,7 +100,12 @@ def add_kbnf_filter(
         try:
             # Validate KBNF and create formatter
             f = FormatterBuilder()
-            # TODO: Implement this
+            # Build the extractor in its own statement: a replacement field
+            # spanning several lines inside f"..." only parses on Python 3.12+
+            kbnf_extractor = f.extractor(
+                lambda nonterminal: CustomExtractor(nonterminal, kbnf_string)
+            )
+            f.append_line(f"{kbnf_extractor}")
         except Exception:
             logger.error(
                 "Skipping because the KBNF string couldn't be parsed. "
@@ -111,3 +118,36 @@ def add_kbnf_filter(
 
         # Append the filters
         self.filters.append(lmfilter)
+
+
+class CustomExtractor(NonterminalExtractor):
+    """Extractor backed by a caller-supplied KBNF grammar string."""
+
+    def __init__(self, nonterminal: str, kbnf_string: str):
+        """Store the raw grammar alongside the assigned nonterminal.
+
+        Args:
+            nonterminal: Name passed through to NonterminalExtractor.
+            kbnf_string: KBNF grammar text supplied by the caller.
+        """
+        super().__init__(nonterminal)
+        self.kbnf_string = kbnf_string
+
+    # Fails without an extract function defined
+    # TODO: Figure out how to do this properly
+    def extract(self, input_str: str) -> typing.Optional[tuple[str, typing.Any]]:
+        """Treat the whole input as the extracted value.
+
+        Returns:
+            An empty string followed by input_str itself; presumably the
+            (remaining text, extracted value) pair formatron expects.
+        """
+        return "", input_str
+
+    @property
+    def kbnf_definition(self) -> str:
+        """Return the grammar with "start" renamed to this nonterminal."""
+        # NOTE(review): str.replace rewrites every occurrence of "start",
+        # including inside longer rule names and quoted literals - confirm
+        # that is acceptable for the grammars callers send
+        return self.kbnf_string.replace("start", self.nonterminal)