diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 43b86334..f656f7f2 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -27,6 +27,14 @@
 from common.utils import coalesce, unwrap
 from common.logger import init_logger
 
+# Optional imports for dependencies
+try:
+    from exllamav2 import ExLlamaV2Cache_Q4
+
+    _exllamav2_has_int4 = True
+except ImportError:
+    _exllamav2_has_int4 = False
+
 logger = init_logger(__name__)
 
 
@@ -46,7 +54,7 @@ class ExllamaV2Container:
     active_loras: List[ExLlamaV2Lora] = []
 
     # Internal config vars
-    cache_fp8: bool = False
+    cache_mode: str = "FP16"
     use_cfg: bool = False
 
     # GPU split vars
@@ -109,7 +117,15 @@ def progress(loaded_modules: int, total_modules: int,
 
         self.quiet = quiet
 
-        self.cache_fp8 = "cache_mode" in kwargs and kwargs["cache_mode"] == "FP8"
+        cache_mode = unwrap(kwargs.get("cache_mode"), "FP16")
+        if cache_mode == "Q4" and not _exllamav2_has_int4:
+            logger.warning(
+                "Q4 cache is not available "
+                "in the currently installed ExllamaV2 version. Using FP16."
+            )
+            cache_mode = "FP16"
+
+        self.cache_mode = cache_mode
 
         # Turn off GPU split if the user is using 1 GPU
         gpu_count = torch.cuda.device_count()
@@ -398,7 +414,12 @@ def progress(loaded_modules: int, total_modules: int)
                 yield value
 
         batch_size = 2 if self.use_cfg else 1
-        if self.cache_fp8:
+
+        if self.cache_mode == "Q4" and _exllamav2_has_int4:
+            self.cache = ExLlamaV2Cache_Q4(
+                self.model, lazy=self.gpu_split_auto, batch_size=batch_size
+            )
+        elif self.cache_mode == "FP8":
             self.cache = ExLlamaV2Cache_8bit(
                 self.model, lazy=self.gpu_split_auto, batch_size=batch_size
             )
diff --git a/main.py b/main.py
index 38480a59..c866c82f 100644
--- a/main.py
+++ b/main.py
@@ -149,7 +149,7 @@ async def get_current_model():
         rope_scale=MODEL_CONTAINER.config.scale_pos_emb,
         rope_alpha=MODEL_CONTAINER.config.scale_alpha_value,
         max_seq_len=MODEL_CONTAINER.config.max_seq_len,
-        cache_mode="FP8" if MODEL_CONTAINER.cache_fp8 else "FP16",
+        cache_mode=MODEL_CONTAINER.cache_mode,
         prompt_template=prompt_template.name if prompt_template else None,
         num_experts_per_token=MODEL_CONTAINER.config.num_experts_per_token,
         use_cfg=MODEL_CONTAINER.use_cfg,
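
For reviewers, a minimal standalone sketch of the version-gating pattern this diff relies on: `ExLlamaV2Cache_Q4` is imported once at module load behind a try/except, and a requested `Q4` mode falls back to `FP16` with a warning when the installed exllamav2 build lacks the quantized cache. The helper name `resolve_cache_mode` and the plain `logging` setup are illustrative only and not part of the actual codebase.

```python
import logging
from typing import Optional

logger = logging.getLogger(__name__)

# Feature-detect the Q4 cache class; older exllamav2 builds only ship
# ExLlamaV2Cache and ExLlamaV2Cache_8bit, so the import may fail.
try:
    from exllamav2 import ExLlamaV2Cache_Q4  # noqa: F401

    _exllamav2_has_int4 = True
except ImportError:
    _exllamav2_has_int4 = False


def resolve_cache_mode(requested: Optional[str]) -> str:
    """Illustrative helper: clamp the requested cache mode to what the
    installed exllamav2 version actually supports."""
    mode = requested or "FP16"
    if mode == "Q4" and not _exllamav2_has_int4:
        logger.warning(
            "Q4 cache is not available in the currently installed "
            "ExllamaV2 version. Using FP16."
        )
        mode = "FP16"
    return mode


if __name__ == "__main__":
    # Prints "Q4" when the quantized cache is available, "FP16" otherwise.
    print(resolve_cache_mode("Q4"))
```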