diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index d472117e..acb0806d 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -10,6 +10,7 @@ from exllamav2 import (
     ExLlamaV2,
     ExLlamaV2Config,
+    ExLlamaV2CacheBase,
     ExLlamaV2Cache,
     ExLlamaV2Cache_Q4,
     ExLlamaV2Cache_Q6,
@@ -148,6 +149,8 @@ def progress(loaded_modules: int, total_modules: int,
         """
 
         self.quiet = quiet
+
+        # Get cache mode
         self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16")
 
         # Turn off GPU split if the user is using 1 GPU
@@ -577,8 +580,9 @@ def progress(loaded_modules: int, total_modules: int)
                 logger.info("Loading draft model: " + self.draft_config.model_dir)
 
             # Draft uses the autosplit loader, so create a cache that reflects this
+            draft_cache_class = self.get_cache_class(self.draft_cache_mode)
             self.draft_cache = self.create_cache(
-                cache_mode=self.draft_cache_mode,
+                cache_class=draft_cache_class,
                 autosplit=True,
                 use_tp=False,
             )
@@ -600,6 +604,9 @@ def progress(loaded_modules: int, total_modules: int)
         if not self.quiet:
             logger.info("Loading model: " + self.config.model_dir)
 
+        # Get class of the model cache
+        cache_class = self.get_cache_class(self.cache_mode)
+
         # Load model with manual split
         # Entrypoint for single GPU users
         if self.use_tp:
@@ -608,6 +615,7 @@
             for value in self.model.load_tp_gen(
                 self.gpu_split,
                 callback_gen=progress_callback,
+                expect_cache_base=cache_class,
                 expect_cache_tokens=self.cache_size,
             ):
                 if value:
@@ -624,7 +632,7 @@
 
         # Create the model cache
         self.cache = self.create_cache(
-            cache_mode=self.cache_mode,
+            cache_class=cache_class,
             autosplit=self.gpu_split_auto,
             use_tp=self.use_tp,
         )
@@ -646,35 +654,39 @@
         input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
         self.model.forward(input_ids, cache=self.cache, preprocess_only=True)
 
-    def create_cache(self, cache_mode: str, autosplit: bool, use_tp: bool):
-        if has_tp and use_tp:
-            if self.cache_mode != "FP16":
-                logger.warning(
-                    "Tensor parallel does not currently allow for use of "
-                    "a quantized K/V cache. Using the specialized TP cache."
-                )
+    # TODO: Maybe make a wrapper class with an ID instead of a utility function
+    def get_cache_class(self, cache_mode: str):
+        """Utility function to get a cache class based on user preference."""
 
-            return ExLlamaV2Cache_TP(
-                self.model,
-                max_seq_len=self.cache_size,
-                batch_size=1,
-            )
-
-        cache_type = ExLlamaV2Cache
         match cache_mode:
             case "Q4":
-                cache_type = ExLlamaV2Cache_Q4
+                return ExLlamaV2Cache_Q4
             case "Q6":
-                cache_type = ExLlamaV2Cache_Q6
+                return ExLlamaV2Cache_Q6
             case "Q8":
-                cache_type = ExLlamaV2Cache_Q8
+                return ExLlamaV2Cache_Q8
+            case _:
+                return ExLlamaV2Cache
 
-        return cache_type(
-            self.model,
-            max_seq_len=self.cache_size,
-            lazy=autosplit,
-            batch_size=1,
-        )
+    def create_cache(
+        self, cache_class: ExLlamaV2CacheBase, autosplit: bool, use_tp: bool
+    ):
+        """Utility function to create a model cache."""
+
+        if has_tp and use_tp:
+            return ExLlamaV2Cache_TP(
+                self.model,
+                base=cache_class,
+                max_seq_len=self.cache_size,
+                batch_size=1,
+            )
+        else:
+            return cache_class(
+                self.model,
+                max_seq_len=self.cache_size,
+                lazy=autosplit,
+                batch_size=1,
+            )
 
     async def create_generator(self):
         try:
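
For review context, a minimal sketch of how the two new helpers are meant to compose. The `container` name below is a hypothetical stand-in for an instance of the model container class patched above; it is not part of the diff.

```python
from exllamav2 import ExLlamaV2Cache_Q4

# Hypothetical sketch: `container` is assumed to be an already-constructed
# instance of the model container class in backends/exllamav2/model.py.

# Resolve the user-facing cache_mode string ("FP16", "Q4", "Q6", "Q8") to a
# cache *class*; unrecognized values fall back to the FP16 ExLlamaV2Cache.
cache_class = container.get_cache_class("Q4")
assert cache_class is ExLlamaV2Cache_Q4

# Pass the class (not an instance) to create_cache. With tensor parallel
# active it becomes the `base=` argument of ExLlamaV2Cache_TP; otherwise it
# is instantiated directly with lazy=autosplit for the autosplit loader.
cache = container.create_cache(
    cache_class=cache_class,
    autosplit=True,
    use_tp=False,
)
```

Resolving the class once also lets `load_tp_gen` receive the same class via `expect_cache_base`, so the loader and `create_cache` stay in agreement about the cache type.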