Model: Add support for Q4 cache
Add this in addition to the existing 8-bit and 16-bit caches. Passing "Q4" as
the cache_mode request parameter selects it on model load.

Signed-off-by: kingbri <[email protected]>
bdashore3 committed Mar 6, 2024
1 parent 0b25c20 commit 9a007c4
Showing 2 changed files with 25 additions and 4 deletions.
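
For illustration only (not part of the commit): with a server built on this backend running, the new mode is requested by passing cache_mode in the model load request; per the diff below, the recognized values are "FP16" (the default), "FP8", and "Q4". A minimal sketch in Python using the requests library, assuming the admin load endpoint is POST /v1/model/load, takes a JSON body with "name" and "cache_mode" fields, and authenticates via an x-admin-key header; verify the exact route, fields, and auth header against the server's API docs.

import requests

# Hypothetical values: adjust the base URL, model name, and admin key
# to match the running server's configuration.
BASE_URL = "http://localhost:5000"
ADMIN_KEY = "your-admin-key"

payload = {
    "name": "MyModel-4.0bpw-exl2",  # hypothetical model folder to load
    "cache_mode": "Q4",             # falls back to FP16 (with a warning) if
                                    # the installed exllamav2 version lacks
                                    # ExLlamaV2Cache_Q4
}

response = requests.post(
    f"{BASE_URL}/v1/model/load",
    headers={"x-admin-key": ADMIN_KEY},
    json=payload,
)
response.raise_for_status()
print(response.text)
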
27 changes: 24 additions & 3 deletions backends/exllamav2/model.py
@@ -27,6 +27,14 @@
 from common.utils import coalesce, unwrap
 from common.logger import init_logger
 
+# Optional imports for dependencies
+try:
+    from exllamav2 import ExLlamaV2Cache_Q4
+
+    _exllamav2_has_int4 = True
+except ImportError:
+    _exllamav2_has_int4 = False
+
 logger = init_logger(__name__)

@@ -46,7 +54,7 @@ class ExllamaV2Container:
     active_loras: List[ExLlamaV2Lora] = []
 
     # Internal config vars
-    cache_fp8: bool = False
+    cache_mode: str = "FP16"
     use_cfg: bool = False
 
     # GPU split vars
@@ -109,7 +117,15 @@ def progress(loaded_modules: int, total_modules: int,
 
         self.quiet = quiet
 
-        self.cache_fp8 = "cache_mode" in kwargs and kwargs["cache_mode"] == "FP8"
+        cache_mode = unwrap(kwargs.get("cache_mode"), "FP16")
+        if cache_mode == "Q4" and not _exllamav2_has_int4:
+            logger.warning(
+                "Q4 cache is not available "
+                "in the currently installed ExllamaV2 version. Using FP16."
+            )
+            cache_mode = "FP16"
+
+        self.cache_mode = cache_mode
 
         # Turn off GPU split if the user is using 1 GPU
         gpu_count = torch.cuda.device_count()
@@ -398,7 +414,12 @@ def progress(loaded_modules: int, total_modules: int)
             yield value
 
         batch_size = 2 if self.use_cfg else 1
-        if self.cache_fp8:
+
+        if self.cache_mode == "Q4" and _exllamav2_has_int4:
+            self.cache = ExLlamaV2Cache_Q4(
+                self.model, lazy=self.gpu_split_auto, batch_size=batch_size
+            )
+        elif self.cache_mode == "FP8":
             self.cache = ExLlamaV2Cache_8bit(
                 self.model, lazy=self.gpu_split_auto, batch_size=batch_size
             )
2 changes: 1 addition & 1 deletion main.py
@@ -149,7 +149,7 @@ async def get_current_model():
         rope_scale=MODEL_CONTAINER.config.scale_pos_emb,
         rope_alpha=MODEL_CONTAINER.config.scale_alpha_value,
         max_seq_len=MODEL_CONTAINER.config.max_seq_len,
-        cache_mode="FP8" if MODEL_CONTAINER.cache_fp8 else "FP16",
+        cache_mode=MODEL_CONTAINER.cache_mode,
         prompt_template=prompt_template.name if prompt_template else None,
         num_experts_per_token=MODEL_CONTAINER.config.num_experts_per_token,
         use_cfg=MODEL_CONTAINER.use_cfg,
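
For illustration only (not part of the commit): with this change, the current-model endpoint reports the container's cache_mode string directly instead of deriving it from the removed cache_fp8 flag. A minimal sketch, assuming the route is GET /v1/model and requests authenticate via an x-api-key header; both details are assumptions to check against the server's API docs.

import requests

# Hypothetical base URL and API key; adjust to the running server.
response = requests.get(
    "http://localhost:5000/v1/model",
    headers={"x-api-key": "your-api-key"},
)
response.raise_for_status()

# The returned model card now includes the actual cache_mode string
# ("FP16", "FP8", or "Q4") rather than a value derived from the old
# cache_fp8 boolean.
print(response.json())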