diff --git a/backends/exllamav2/grammar.py b/backends/exllamav2/grammar.py
index 3ad2f44..47c5ed5 100644
--- a/backends/exllamav2/grammar.py
+++ b/backends/exllamav2/grammar.py
@@ -1,110 +1,16 @@
 import traceback
-from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
-from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
-from lmformatenforcer import (
-    JsonSchemaParser,
-    RegexParser,
-    TokenEnforcer,
-    CharacterLevelParser,
-)
-from lmformatenforcer.integrations.exllamav2 import (
-    build_token_enforcer_tokenizer_data,
-)
-from loguru import logger
-from typing import List
+import typing
 from functools import lru_cache
+from typing import List
 
-
-class OutlinesTokenizerWrapper:
-    """Wrapper for Outlines tokenizer"""
-
-    def __init__(self, tokenizer):
-        self.tokenizer = tokenizer
-        id_to_piece = self.tokenizer.get_id_to_piece_list()
-        self.vocabulary = {piece: idx for idx, piece in enumerate(id_to_piece)}
-        self.eos_token_id = self.tokenizer.eos_token_id
-        self.eos_token = id_to_piece[self.tokenizer.eos_token_id]
-        self.special_tokens = list(self.tokenizer.extended_id_to_piece.keys())
-
-    def convert_token_to_string(self, token):
-        return token
-
-    def decode(self, tokens):
-        s = ""
-        id_to_piece = self.tokenizer.get_id_to_piece_list()
-        for t in tokens:
-            s += id_to_piece[t]
-        return s
-
-
-class ExLlamaV2EbnfFilter(ExLlamaV2Filter):
-    """Filter class for context-free grammar via outlines"""
-
-    def __init__(self, model, tokenizer, grammar):
-        from outlines.fsm.fsm import CFGFSM
-
-        super().__init__(model, tokenizer)
-
-        self.wrapped_tokenizer = OutlinesTokenizerWrapper(tokenizer)
-        self.fsm = CFGFSM(grammar, self.wrapped_tokenizer)
-        self.state = self.fsm.first_state
-
-    def begin(self, prefix_str=""):
-        self.state = self.fsm.first_state
-
-    def feed(self, token):
-        self.state = self.fsm.next_state(self.state, token.item())
-
-    def next(self):
-        return self.fsm.allowed_token_ids(self.state), set()
-
-    def use_background_worker(self):
-        return True
-
-
-@lru_cache(10)
-def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
-    return build_token_enforcer_tokenizer_data(tokenizer)
-
-
-class ExLlamaV2TokenEnforcerFilter(ExLlamaV2Filter):
-    """Filter class for LMFE"""
-
-    token_sequence: List[int]
-
-    def __init__(
-        self,
-        model: ExLlamaV2,
-        tokenizer: ExLlamaV2Tokenizer,
-        character_level_parser: CharacterLevelParser,
-    ):
-        super().__init__(model, tokenizer)
-        tokenizer_data = _get_lmfe_tokenizer_data(tokenizer)
-        self.token_enforcer = TokenEnforcer(tokenizer_data, character_level_parser)
-        self.token_sequence = []
-
-    def begin(self, prefix_str: str):
-        self.token_sequence = []
-
-    def feed(self, token):
-        self.token_sequence.append(int(token[0][0]))
-
-    def next(self):
-        allowed_tokens = self.token_enforcer.get_allowed_tokens(self.token_sequence)
-        if not hasattr(self, "allow_return_type_list"):
-            return set(allowed_tokens), set()
-        else:
-            return sorted(allowed_tokens), []
-
-    def use_background_worker(self):
-        return True
-
-
-def clear_grammar_func_cache():
-    """Flush tokenizer_data cache to avoid holding references to
-    tokenizers after unloading a model"""
-
-    _get_lmfe_tokenizer_data.cache_clear()
+import torch
+from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
+from exllamav2.generator.filters import ExLlamaV2Filter
+from formatron.extractor import NonterminalExtractor
+from formatron.formatter import FormatterBuilder
+from formatron.integrations.exllamav2 import FormatterFilter, create_engine_vocabulary
+from formatron.schemas import json_schema
+from loguru import logger
 
 
 class ExLlamaV2Grammar:
@@ -117,7 +23,7 @@ def __init__(self):
 
     def add_json_schema_filter(
         self,
-        json_schema: dict,
+        schema: dict,
         model: ExLlamaV2,
         tokenizer: ExLlamaV2Tokenizer,
     ):
@@ -125,7 +31,16 @@ def add_json_schema_filter(
 
         # Create the parser
         try:
-            schema_parser = JsonSchemaParser(json_schema)
+            # Add fields required by formatron if not present
+            if "$id" not in schema:
+                schema["$id"] = "https://example.com/example.json"
+            if "$schema" not in schema:
+                schema["$schema"] = "http://json-schema.org/draft-07/schema#"
+
+            # Validate schema and create formatter
+            schema = json_schema.create_schema(schema)
+            f = FormatterBuilder()
+            f.append_line(f"{f.json(schema)}")
         except Exception:
             traceback.print_exc()
             logger.error(
@@ -135,14 +50,10 @@ def add_json_schema_filter(
 
             return
 
-        # Allow JSON objects or JSON arrays at the top level
-        json_prefixes = ["[", "{"]
-
-        lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, schema_parser)
-        prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, json_prefixes)
+        lmfilter = _create_formatter_filter(model, tokenizer, f)
 
         # Append the filters
-        self.filters.extend([lmfilter, prefix_filter])
+        self.filters.append(lmfilter)
 
     def add_regex_filter(
         self,
@@ -154,7 +65,9 @@ def add_regex_filter(
 
         # Create the parser
         try:
-            pattern_parser = RegexParser(pattern)
+            # Validate regex and create formatter
+            f = FormatterBuilder()
+            f.append_line(f"{f.regex(pattern)}")
         except Exception:
             traceback.print_exc()
             logger.error(
@@ -164,32 +77,82 @@ def add_regex_filter(
 
             return
 
-        lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, pattern_parser)
+        lmfilter = _create_formatter_filter(model, tokenizer, f)
 
         # Append the filters
         self.filters.append(lmfilter)
 
-    def add_ebnf_filter(
+    def add_kbnf_filter(
         self,
-        ebnf_string: str,
+        kbnf_string: str,
         model: ExLlamaV2,
         tokenizer: ExLlamaV2Tokenizer,
     ):
-        """
-        Add an EBNF grammar filter.
-        Possibly replace outlines with an in-house solution in the future.
-        """
+        """Adds an ExLlamaV2 filter based on KBNF grammar."""
 
+        # Create the parser
         try:
-            ebnf_filter = ExLlamaV2EbnfFilter(model, tokenizer, ebnf_string)
-        except ImportError:
+            # Validate KBNF and create formatter
+            f = FormatterBuilder()
+            f.append_line(
+                f"""{f.extractor(lambda nonterminal:
+                CFGExtractor(nonterminal, kbnf_string))}"""
+            )
+        except Exception:
             logger.error(
-                "Skipping EBNF parsing because Outlines is not installed.\n"
-                "Please run the following command in your environment "
-                "to install extra packages:\n"
-                "pip install -U .[extras]"
+                "Skipping because the KBNF string couldn't be parsed. "
+                "Please read the above error for more information."
             )
 
             return
 
-        self.filters.append(ebnf_filter)
+        lmfilter = _create_formatter_filter(model, tokenizer, f)
+
+        # Append the filters
+        self.filters.append(lmfilter)
+
+
+class CFGExtractor(NonterminalExtractor):
+    """Extractor class for KBNF context-free grammar"""
+
+    def __init__(self, nonterminal: str, kbnf_string: str):
+        super().__init__(nonterminal)
+        self.kbnf_string = kbnf_string
+
+    # Return the entire input string as the extracted string
+    def extract(self, input_str: str) -> typing.Optional[tuple[str, typing.Any]]:
+        return "", input_str
+
+    @property
+    def kbnf_definition(self) -> str:
+        return self.kbnf_string.replace("start", self.nonterminal)
+
+
+@lru_cache(1)
+def _create_cached_engine_vocabulary(tokenizer: ExLlamaV2Tokenizer):
+    """Build and cache engine vocabulary on first grammar run"""
+
+    return create_engine_vocabulary(tokenizer)
+
+
+def _create_formatter_filter(
+    model: ExLlamaV2, tokenizer: ExLlamaV2Tokenizer, formatter_builder: FormatterBuilder
+) -> ExLlamaV2Filter:
+    """
+    Create a formatter filter for the ExLlamaV2 engine.
+    Minimalist clone of formatron.integrations.exllamav2.create_formatter_filter
+    with lru_cache enabled for engine vocabulary
+    """
+
+    vocab = _create_cached_engine_vocabulary(tokenizer)
+    f = formatter_builder.build(
+        vocab, lambda tokens: tokenizer.decode(torch.tensor(tokens))
+    )
+    return FormatterFilter(model, tokenizer, f)
+
+
+def clear_grammar_func_cache():
+    """Flush the engine vocabulary cache to avoid holding references to
+    tokenizers after unloading a model"""
+
+    _create_cached_engine_vocabulary.cache_clear()
diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index ff11531..50cef42 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -1194,7 +1194,7 @@ async def generate_gen(
-        # Add EBNF filter if it exists
+        # Add KBNF filter if it exists
         grammar_string = unwrap(kwargs.get("grammar_string"))
         if grammar_string:
-            grammar_handler.add_ebnf_filter(grammar_string, self.model, self.tokenizer)
+            grammar_handler.add_kbnf_filter(grammar_string, self.model, self.tokenizer)
 
         # Set banned strings
         banned_strings: List[str] = unwrap(kwargs.get("banned_strings"), [])
diff --git a/pyproject.toml b/pyproject.toml
index c0b31f3..021b6a9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,7 +26,8 @@ dependencies = [
     "sse-starlette",
     "packaging",
     "tokenizers",
-    "lm-format-enforcer >= 0.9.6",
+    "formatron",
+    "kbnf>=0.4.1",
     "aiofiles",
     "aiohttp",
     "async_lru",
@@ -53,7 +54,6 @@ dependencies = [
 [project.optional-dependencies]
 extras = [
     # Heavy dependencies that aren't for everyday use
-    "outlines",
     "infinity-emb",
     "sentence-transformers",
 ]
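
The diff swaps lm-format-enforcer and outlines for formatron across all three filter paths. As a rough usage sketch (not part of the diff), the migrated handler could be driven as below. `model` and `tokenizer` are placeholders for an already-loaded ExLlamaV2 pair, the import path assumes the repository root, and the KBNF string is an assumed example of kbnf rule syntax.

# Hypothetical driver for the migrated grammar handler; `model` and
# `tokenizer` below are placeholders, not part of the diff.
from backends.exllamav2.grammar import ExLlamaV2Grammar

grammar_handler = ExLlamaV2Grammar()

# JSON schema constraint: "$id"/"$schema" are injected by
# add_json_schema_filter when absent, so a bare schema is accepted
grammar_handler.add_json_schema_filter(
    {"type": "object", "properties": {"answer": {"type": "string"}}},
    model,  # placeholder: loaded ExLlamaV2 instance
    tokenizer,  # placeholder: matching ExLlamaV2Tokenizer
)

# Regex constraint, now validated through FormatterBuilder
grammar_handler.add_regex_filter(r"(yes|no)", model, tokenizer)

# KBNF grammar constraint via the renamed add_kbnf_filter; CFGExtractor
# rewrites the "start" nonterminal, so rules are written against "start"
grammar_handler.add_kbnf_filter('start ::= "yes" | "no";', model, tokenizer)

# grammar_handler.filters now holds the FormatterFilter objects that
# generate_gen hands to the ExLlamaV2 generator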
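
The `@lru_cache(1)` on `_create_cached_engine_vocabulary`, paired with `cache_clear()` in `clear_grammar_func_cache`, lets the expensive vocabulary build survive repeated grammar requests without pinning a stale tokenizer across a model swap. A minimal, self-contained sketch of that pattern, with illustrative names not taken from the diff:

from functools import lru_cache

@lru_cache(1)
def build_vocabulary(tokenizer_id: str) -> list[str]:
    # Expensive step; runs only on a cache miss
    print(f"building vocabulary for {tokenizer_id}")
    return [f"token_{i}" for i in range(5)]

build_vocabulary("model-a")  # miss: prints and builds
build_vocabulary("model-a")  # hit: returns the cached list silently

# On model unload, flush the cache so the old object can be collected
build_vocabulary.cache_clear()
build_vocabulary("model-b")  # miss again: rebuilt for the new model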