diff --git a/backends/exllamav2/grammar.py b/backends/exllamav2/grammar.py
index fa3306a3..2fac9f4d 100644
--- a/backends/exllamav2/grammar.py
+++ b/backends/exllamav2/grammar.py
@@ -2,9 +2,13 @@
 from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
 from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
 from lmformatenforcer import JsonSchemaParser, RegexParser
-from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter
+from lmformatenforcer.integrations.exllamav2 import (
+    ExLlamaV2TokenEnforcerFilter,
+    build_token_enforcer_tokenizer_data,
+)
 from loguru import logger
 from typing import List
+from functools import lru_cache
 
 
 class OutlinesTokenizerWrapper:
@@ -51,6 +55,18 @@ def next(self):
         return self.fsm.allowed_token_ids(self.state), set()
 
 
+@lru_cache(10)
+def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
+    return build_token_enforcer_tokenizer_data(tokenizer)
+
+
+def clear_grammar_func_cache():
+    """Flush tokenizer_data cache to avoid holding references to
+    tokenizers after unloading a model"""
+
+    _get_lmfe_tokenizer_data.cache_clear()
+
+
 class ExLlamaV2Grammar:
     """ExLlamaV2 class for various grammar filters/parsers."""
 
@@ -82,7 +98,9 @@ def add_json_schema_filter(
         # Allow JSON objects or JSON arrays at the top level
         json_prefixes = ["[", "{"]
 
-        lmfilter = ExLlamaV2TokenEnforcerFilter(schema_parser, tokenizer)
+        lmfilter = ExLlamaV2TokenEnforcerFilter(
+            schema_parser, _get_lmfe_tokenizer_data(tokenizer)
+        )
         prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, json_prefixes)
 
         # Append the filters
@@ -107,7 +125,9 @@ def add_regex_filter(
 
             return
 
-        lmfilter = ExLlamaV2TokenEnforcerFilter(pattern_parser, tokenizer)
+        lmfilter = ExLlamaV2TokenEnforcerFilter(
+            pattern_parser, _get_lmfe_tokenizer_data(tokenizer)
+        )
 
         # Append the filters
         self.filters.append(lmfilter)
diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 5a425c7b..5f4e86b7 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -26,7 +26,10 @@
 from loguru import logger
 from typing import List, Optional, Union
 
-from backends.exllamav2.grammar import ExLlamaV2Grammar
+from backends.exllamav2.grammar import (
+    ExLlamaV2Grammar,
+    clear_grammar_func_cache,
+)
 from backends.exllamav2.utils import (
     exllama_disabled_flash_attn,
     hardware_supports_flash_attn,
@@ -704,6 +707,10 @@ async def unload(self, loras_only: bool = False, **kwargs):
         # Wait for other jobs to finish
         await self.wait_for_jobs(kwargs.get("skip_wait"))
 
+        # Delete references held in the grammar module
+        clear_grammar_func_cache()
+
+        # Unload LoRAs
         if self.generator and self.generator.generator.current_loras:
             for lora in self.generator.generator.current_loras:
                 lora.unload()
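
For context, the patch works because `functools.lru_cache` keeps strong references to the arguments it has seen, so the cached tokenizer data (and the tokenizer itself) would otherwise outlive the unloaded model; `cache_clear()` drops those references. A minimal standalone sketch of the same pattern, not the tabbyAPI or lm-format-enforcer API (`Tokenizer`, `get_tokenizer_data`, and `clear_cache` are placeholder names):

```python
from functools import lru_cache


class Tokenizer:
    """Stand-in for an expensive-to-wrap tokenizer object."""


@lru_cache(maxsize=10)
def get_tokenizer_data(tokenizer: Tokenizer) -> dict:
    # Pretend this is an expensive precomputation keyed on the tokenizer.
    return {"id": id(tokenizer)}


def clear_cache():
    # lru_cache wrappers expose cache_clear(); without calling it, the
    # cached entries keep the tokenizer alive after the model is unloaded.
    get_tokenizer_data.cache_clear()


tok = Tokenizer()
assert get_tokenizer_data(tok) is get_tokenizer_data(tok)  # second call is a cache hit
clear_cache()  # drop references so `tok` can be garbage-collected
```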