From 5002617eac97c7d1cdaa4518d6eb404619d4530d Mon Sep 17 00:00:00 2001
From: kingbri
Date: Fri, 16 Aug 2024 15:17:03 -0400
Subject: [PATCH] Model: Split cache creation into a common function

Unifies the switch statement across both draft and model caches.

Signed-off-by: kingbri
---
 backends/exllamav2/model.py | 94 ++++++++++++++++---------------------
 1 file changed, 42 insertions(+), 52 deletions(-)

diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index d9f2af69..0648dbc5 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -548,30 +548,12 @@ def progress(loaded_modules: int, total_modules: int)
             if not self.quiet:
                 logger.info("Loading draft model: " + self.draft_config.model_dir)
 
-            if self.draft_cache_mode == "Q4":
-                self.draft_cache = ExLlamaV2Cache_Q4(
-                    self.draft_model,
-                    max_seq_len=self.cache_size,
-                    lazy=True,
-                )
-            elif self.draft_cache_mode == "Q6":
-                self.draft_cache = ExLlamaV2Cache_Q6(
-                    self.draft_model,
-                    max_seq_len=self.cache_size,
-                    lazy=True,
-                )
-            elif self.draft_cache_mode == "Q8":
-                self.draft_cache = ExLlamaV2Cache_Q8(
-                    self.draft_model,
-                    max_seq_len=self.cache_size,
-                    lazy=True,
-                )
-            else:
-                self.draft_cache = ExLlamaV2Cache(
-                    self.draft_model,
-                    max_seq_len=self.cache_size,
-                    lazy=True,
-                )
+            self.draft_cache = self.create_cache(
+                cache_mode=self.draft_cache_mode,
+                autosplit=True,
+                model=self.draft_model,
+            )
+
             for value in self.draft_model.load_autosplit_gen(
                 self.draft_cache,
                 reserve_vram=autosplit_reserve,
@@ -601,34 +583,11 @@ def progress(loaded_modules: int, total_modules: int)
                 if value:
                     yield value
 
-        if self.cache_mode == "Q4":
-            self.cache = ExLlamaV2Cache_Q4(
-                self.model,
-                max_seq_len=self.cache_size,
-                lazy=self.gpu_split_auto,
-                batch_size=1,
-            )
-        elif self.cache_mode == "Q6":
-            self.cache = ExLlamaV2Cache_Q6(
-                self.model,
-                max_seq_len=self.cache_size,
-                lazy=self.gpu_split_auto,
-                batch_size=1,
-            )
-        elif self.cache_mode == "Q8":
-            self.cache = ExLlamaV2Cache_Q8(
-                self.model,
-                max_seq_len=self.cache_size,
-                lazy=self.gpu_split_auto,
-                batch_size=1,
-            )
-        else:
-            self.cache = ExLlamaV2Cache(
-                self.model,
-                max_seq_len=self.cache_size,
-                lazy=self.gpu_split_auto,
-                batch_size=1,
-            )
+        self.cache = self.create_cache(
+            cache_mode=self.cache_mode,
+            autosplit=self.gpu_split_auto,
+            model=self.model,
+        )
 
         # Load model with autosplit
         if self.gpu_split_auto:
@@ -647,6 +606,37 @@ def progress(loaded_modules: int, total_modules: int)
         input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
         self.model.forward(input_ids, cache=self.cache, preprocess_only=True)
 
+    def create_cache(self, cache_mode: str, autosplit: bool, model: ExLlamaV2):
+        match cache_mode:
+            case "Q4":
+                return ExLlamaV2Cache_Q4(
+                    model,
+                    max_seq_len=self.cache_size,
+                    lazy=autosplit,
+                    batch_size=1,
+                )
+            case "Q6":
+                return ExLlamaV2Cache_Q6(
+                    model,
+                    max_seq_len=self.cache_size,
+                    lazy=autosplit,
+                    batch_size=1,
+                )
+            case "Q8":
+                return ExLlamaV2Cache_Q8(
+                    model,
+                    max_seq_len=self.cache_size,
+                    lazy=autosplit,
+                    batch_size=1,
+                )
+            case _:
+                return ExLlamaV2Cache(
+                    model,
+                    max_seq_len=self.cache_size,
+                    lazy=autosplit,
+                    batch_size=1,
+                )
+
     async def create_generator(self):
         try:
             # Don't acquire locks unless a model is loaded
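
For readers skimming the diff, here is a minimal standalone sketch of the pattern the new create_cache helper implements: map the cache-mode string to one of exllamav2's cache classes and construct it against a given model with a shared set of arguments. The ExLlamaV2Cache* classes and their keyword arguments are taken from the patch above; the free-standing build_cache wrapper, its parameter names, and the table-driven dispatch are illustrative assumptions, not code from this patch.

# Illustrative sketch only -- mirrors the dispatch done by create_cache() in the patch.
# The cache classes are real exllamav2 types (model.py already imports them);
# build_cache itself is a hypothetical helper, not part of the change.
from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Cache,
    ExLlamaV2Cache_Q4,
    ExLlamaV2Cache_Q6,
    ExLlamaV2Cache_Q8,
)

# One cache class per quantization mode; the FP16 ExLlamaV2Cache is the fallback.
CACHE_CLASSES = {
    "Q4": ExLlamaV2Cache_Q4,
    "Q6": ExLlamaV2Cache_Q6,
    "Q8": ExLlamaV2Cache_Q8,
}


def build_cache(model: ExLlamaV2, cache_mode: str, cache_size: int, autosplit: bool):
    """Construct a KV cache for `model`, choosing the class from `cache_mode`."""
    cache_cls = CACHE_CLASSES.get(cache_mode, ExLlamaV2Cache)

    # lazy=True defers VRAM allocation until the weights have been placed on
    # devices, which is what the autosplit loading path expects.
    return cache_cls(
        model,
        max_seq_len=cache_size,
        lazy=autosplit,
        batch_size=1,
    )

A table-driven lookup like this is behaviorally equivalent to the match statement in the patch; both keep the constructor arguments in one place, which is the point of unifying the draft and main cache paths.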