Refactored exl2 method to add in more features supported by the exllamav2 library
psych0v0yager committed Mar 6, 2024
1 parent a62ff00 commit 14d0160
Showing 1 changed file with 92 additions and 17 deletions.
109 changes: 92 additions & 17 deletions outlines/models/exllamav2.py
@@ -1,3 +1,4 @@
import os
from typing import TYPE_CHECKING, Optional

import torch
@@ -76,43 +77,117 @@ def __call__(self, input_ids: torch.LongTensor, *_) -> torch.FloatTensor:

def exl2(
    model_path: str,
    device: Optional[torch.device] = None,
    max_seq_len: Optional[int] = None,
    scale_pos_emb: Optional[float] = None,
    scale_alpha_value: Optional[float] = None,
    no_flash_attn: Optional[bool] = None,
    num_experts_per_token: Optional[int] = None,
    cache_8bit: bool = False,
    cache_q4: bool = False,
    tokenizer_kwargs: dict = {},
    gpu_split: Optional[str] = None,
    low_mem: Optional[bool] = None,
    verbose: Optional[bool] = None,
) -> ExLlamaV2Model:
"""
Load an ExLlamaV2 model.
Args:
model_path (str): Path to the model directory.
device (Optional[torch.device], optional): Device to load the model on. Defaults to None.
max_seq_len (Optional[int], optional): Maximum sequence length. Defaults to None.
scale_pos_emb (Optional[float], optional): Scale factor for positional embeddings. Defaults to None.
scale_alpha_value (Optional[float], optional): Scale alpha value. Defaults to None.
no_flash_attn (Optional[bool], optional): Disable flash attention. Defaults to None.
num_experts_per_token (Optional[int], optional): Number of experts per token. Defaults to None.
cache_8bit (bool, optional): Use 8-bit cache. Defaults to False.
cache_q4 (bool, optional): Use Q4 cache. Defaults to False.
tokenizer_kwargs (dict, optional): Additional keyword arguments for the tokenizer. Defaults to {}.
gpu_split (str): \"auto\", or VRAM allocation per GPU in GB. Auto will use exllama's autosplit feature
low_mem (bool, optional): Enable VRAM optimizations, potentially trading off speed
verbose (bool, optional): Enable if you want debugging statements
Returns:
ExLlamaV2Model: Loaded ExLlamaV2 model.
Raises:
ImportError: If the `exllamav2` library is not installed.
"""

    try:
        from exllamav2 import (  # , ExLlamaV2Cache_Q4
            ExLlamaV2,
            ExLlamaV2Cache,
            ExLlamaV2Cache_8bit,
            ExLlamaV2Config,
        )
        from transformers import AutoTokenizer
    except ImportError:
        raise ImportError(
            "The `exllamav2` library needs to be installed in order to use `exllamav2` models."
        )

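    # Use exllamav2's fasttensors loading everywhere except Windows (os.name == "nt")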
    if os.name != "nt":
        use_fasttensors = True
    else:
        use_fasttensors = False

    # Create config

    config = ExLlamaV2Config()
    config.model_dir = model_path
    config.fasttensors = use_fasttensors
    config.prepare()

    # Set config options

    config.max_seq_len = max_seq_len
    config.scale_pos_emb = scale_pos_emb
    config.scale_alpha_value = scale_alpha_value
    config.no_flash_attn = no_flash_attn
    if num_experts_per_token:
        config.num_experts_per_token = num_experts_per_token
    if low_mem:
        config.set_low_mem()

    # Load the model

    model = ExLlamaV2(config)

    split = None
    if gpu_split and gpu_split != "auto":
        split = [float(alloc) for alloc in gpu_split.split(",")]

    if gpu_split != "auto":
        if not verbose:
            print(" -- Loading model...")
        model.load(split)

    # Load tokenizer

    if not verbose:
        print(" -- Loading tokenizer...")

    # tokenizer = ExLlamaV2Tokenizer(config)

    tokenizer_kwargs.setdefault("padding_side", "left")
    tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_kwargs)
    # tokenizer = TransformerTokenizer(model_path, **tokenizer_kwargs)

    # Create cache

    if cache_8bit:
        cache = ExLlamaV2Cache_8bit(model, lazy=not model.loaded)
    # elif cache_q4:
    #     cache = ExLlamaV2Cache_Q4(model, lazy=not model.loaded)
    else:
        cache = ExLlamaV2Cache(model, lazy=not model.loaded)

    # Load model now if auto split enabled

    if not model.loaded:
        print(" -- Loading model...")
        model.load_autosplit(cache)

    return ExLlamaV2Model(model, tokenizer, device, cache)
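For context, a minimal usage sketch of the refactored loader (not part of this commit). It assumes a local directory containing EXL2-quantized weights; the path below is a placeholder, and the keyword arguments map directly onto the new parameters added here:

from outlines.models.exllamav2 import exl2

# Placeholder path to a local EXL2-quantized model directory
model = exl2(
    "/path/to/model-exl2",
    max_seq_len=4096,
    gpu_split="auto",  # let exllamav2 split layers across available GPUs
    cache_8bit=True,   # use ExLlamaV2Cache_8bit instead of the default FP16 cache
)

With gpu_split="auto", the call to model.load() is skipped and the weights are instead loaded via model.load_autosplit(cache) in the final branch of the function.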
