Merge pull request #146 from theroyallab/tokenizer_data_fix
Tokenizer data fix
bdashore3 authored Jul 8, 2024
2 parents c7ce97f + e97ad9c commit 1743828
Showing 2 changed files with 31 additions and 4 deletions.
26 changes: 23 additions & 3 deletions backends/exllamav2/grammar.py
@@ -2,9 +2,13 @@
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
from lmformatenforcer import JsonSchemaParser, RegexParser
from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter
from lmformatenforcer.integrations.exllamav2 import (
ExLlamaV2TokenEnforcerFilter,
build_token_enforcer_tokenizer_data,
)
from loguru import logger
from typing import List
from functools import lru_cache


class OutlinesTokenizerWrapper:
@@ -51,6 +55,18 @@ def next(self):
return self.fsm.allowed_token_ids(self.state), set()


@lru_cache(10)
def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
return build_token_enforcer_tokenizer_data(tokenizer)


def clear_grammar_func_cache():
"""Flush tokenizer_data cache to avoid holding references to
tokenizers after unloading a model"""

    _get_lmfe_tokenizer_data.cache_clear()


class ExLlamaV2Grammar:
"""ExLlamaV2 class for various grammar filters/parsers."""

@@ -82,7 +98,9 @@ def add_json_schema_filter(
# Allow JSON objects or JSON arrays at the top level
json_prefixes = ["[", "{"]

lmfilter = ExLlamaV2TokenEnforcerFilter(schema_parser, tokenizer)
lmfilter = ExLlamaV2TokenEnforcerFilter(
schema_parser, _get_lmfe_tokenizer_data(tokenizer)
)
prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, json_prefixes)

# Append the filters
@@ -107,7 +125,9 @@ def add_regex_filter(

return

lmfilter = ExLlamaV2TokenEnforcerFilter(pattern_parser, tokenizer)
lmfilter = ExLlamaV2TokenEnforcerFilter(
pattern_parser, _get_lmfe_tokenizer_data(tokenizer)
)

# Append the filters
self.filters.append(lmfilter)
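
For context, here is a minimal, self-contained sketch of the caching pattern the grammar.py change introduces: an lru_cache keyed on the tokenizer object memoizes the expensive tokenizer-data build, and a small helper flushes it. The names Tokenizer, build_tokenizer_data, get_tokenizer_data, and clear_tokenizer_data_cache below are hypothetical stand-ins, not the repository's API.

# Minimal sketch of the caching pattern (hypothetical names, see above).
from functools import lru_cache


class Tokenizer:
    """Stand-in for an expensive-to-wrap tokenizer object."""


def build_tokenizer_data(tokenizer: Tokenizer) -> dict:
    # Stand-in for the costly build step (the real code calls
    # lmformatenforcer's build_token_enforcer_tokenizer_data).
    return {"built_for": id(tokenizer)}


@lru_cache(maxsize=10)
def get_tokenizer_data(tokenizer: Tokenizer) -> dict:
    # Keyed on the tokenizer instance (hashable by identity), so building
    # several filters for the same loaded model reuses one result.
    return build_tokenizer_data(tokenizer)


def clear_tokenizer_data_cache() -> None:
    # lru_cache-wrapped functions expose cache_clear(); dropping the
    # cached entries releases the references to the tokenizer objects.
    get_tokenizer_data.cache_clear()


if __name__ == "__main__":
    tok = Tokenizer()
    assert get_tokenizer_data(tok) is get_tokenizer_data(tok)  # cache hit
    clear_tokenizer_data_cache()  # e.g. when the model is unloaded
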
9 changes: 8 additions & 1 deletion backends/exllamav2/model.py
@@ -26,7 +26,10 @@
from loguru import logger
from typing import List, Optional, Union

from backends.exllamav2.grammar import ExLlamaV2Grammar
from backends.exllamav2.grammar import (
ExLlamaV2Grammar,
clear_grammar_func_cache,
)
from backends.exllamav2.utils import (
exllama_disabled_flash_attn,
hardware_supports_flash_attn,
@@ -704,6 +707,10 @@ async def unload(self, loras_only: bool = False, **kwargs):
# Wait for other jobs to finish
await self.wait_for_jobs(kwargs.get("skip_wait"))

# Delete references held in the grammar module
clear_grammar_func_cache()

# Unload LoRAs
if self.generator and self.generator.generator.current_loras:
for lora in self.generator.generator.current_loras:
lora.unload()
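
The unload() hook matters because an lru_cache holds strong references to its keys: a cached tokenizer would outlive the model unless the cache is flushed. A short, self-contained illustration of that behavior follows; FakeTokenizer and cached_data are hypothetical and unrelated to the repository's code.

import gc
import weakref
from functools import lru_cache


class FakeTokenizer:
    """Hypothetical placeholder standing in for a tokenizer."""


@lru_cache(maxsize=10)
def cached_data(tokenizer: FakeTokenizer) -> str:
    return "tokenizer data"


tok = FakeTokenizer()
ref = weakref.ref(tok)
cached_data(tok)

del tok
gc.collect()
print(ref() is not None)  # True: the cache still keeps the tokenizer alive

cached_data.cache_clear()
gc.collect()
print(ref() is None)  # True: flushing the cache released the last reference
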
