Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactored exl2 method to add LoRA, 8bit cache, and other features supported by exllama #729

Merged
merged 26 commits into from
Mar 13, 2024
Merged
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
14d0160
Refactored exl2 method to add in more features supported by the exlla…
psych0v0yager Mar 6, 2024
a46d86a
Added LoRA support
psych0v0yager Mar 7, 2024
e0544f0
Added unloading as well
psych0v0yager Mar 7, 2024
f533907
fixed LoRA import
psych0v0yager Mar 7, 2024
49165f3
fixed LoRA import
psych0v0yager Mar 7, 2024
741251d
fixed LoRA import
psych0v0yager Mar 7, 2024
d5232d8
Made max_seq_len optional again
psych0v0yager Mar 7, 2024
5caa973
Made remaining params optional
psych0v0yager Mar 7, 2024
1be4858
Removed optional flag on device. Even before my changes it would cras…
psych0v0yager Mar 7, 2024
f75ff1f
Fixed type check
psych0v0yager Mar 7, 2024
e33d344
Fixed the input error
psych0v0yager Mar 7, 2024
be528af
4 bit cache support is now active
psych0v0yager Mar 8, 2024
bd984ad
Refactored exl2 method to add in more features supported by the exlla…
psych0v0yager Mar 6, 2024
aa99511
Added LoRA support
psych0v0yager Mar 7, 2024
eb0bbfb
Added unloading as well
psych0v0yager Mar 7, 2024
b7bd216
fixed LoRA import
psych0v0yager Mar 7, 2024
9b720a3
fixed LoRA import
psych0v0yager Mar 7, 2024
f7d5324
fixed LoRA import
psych0v0yager Mar 7, 2024
e62a70b
Made max_seq_len optional again
psych0v0yager Mar 7, 2024
29135ee
Made remaining params optional
psych0v0yager Mar 7, 2024
ec38d6e
Removed optional flag on device. Even before my changes it would cras…
psych0v0yager Mar 7, 2024
8d51cb5
Fixed type check
psych0v0yager Mar 7, 2024
f9b6e44
Fixed the input error
psych0v0yager Mar 7, 2024
2cf15c2
4 bit cache support is now active
psych0v0yager Mar 8, 2024
a8ab1e6
Made formatting changes
psych0v0yager Mar 11, 2024
7dec179
updated branch
psych0v0yager Mar 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 125 additions & 27 deletions outlines/models/exllamav2.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import os
from typing import TYPE_CHECKING, Optional

import torch

from .transformers import TransformerTokenizer

if TYPE_CHECKING:
from exllamav2 import ExLlamaV2, ExLlamaV2Cache
from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Lora
from transformers import PreTrainedTokenizer

from .transformers import TransformerTokenizer


class ExLlamaV2Model:
"""Represents a `exl2` model."""
Expand All @@ -18,12 +19,14 @@ def __init__(
tokenizer: "PreTrainedTokenizer",
device,
cache: "ExLlamaV2Cache",
lora: Optional["ExLlamaV2Lora"] = None,
):
self.device = device
self.model = model
self.tokenizer = TransformerTokenizer(tokenizer)
self.cache = cache
self.past_seq = None
self.lora = lora

def forward(self, input_ids: torch.LongTensor, *_):
"""Compute a forward pass through the exl2 model."""
Expand All @@ -50,6 +53,7 @@ def forward(self, input_ids: torch.LongTensor, *_):
seq_tensor[longest_prefix:-1].view(1, -1),
self.cache,
preprocess_only=True,
loras=[self.lora],
)
elif seq_tensor.shape[0] == longest_prefix:
self.cache.current_seq_len -= 1
Expand All @@ -61,58 +65,152 @@ def forward(self, input_ids: torch.LongTensor, *_):
seq_tensor[:-1].view(1, -1),
self.cache,
preprocess_only=True,
loras=[self.lora],
)

self.past_seq = seq_tensor

return self.model.forward(seq_tensor[-1:].view(1, -1), self.cache)
return self.model.forward(
seq_tensor[-1:].view(1, -1), self.cache, loras=[self.lora]
)

def __call__(self, input_ids: torch.LongTensor, *_) -> torch.FloatTensor:
logits = self.forward(input_ids)
next_token_logits = logits[..., -1, :]

return next_token_logits, None

def update_lora(self, lora_path: Optional[str] = None):
"""
Update and apply the LoRA to the model.

Args:
lora_path (Optional[str]): The path to the LoRA directory. If None, the LoRA will be unloaded.
"""
try:
from exllamav2 import ExLlamaV2Lora
except ImportError:
raise ImportError(
"The `exllamav2` library needs to be installed in order to use `exllamav2` models."
)
if lora_path is None:
if self.lora is not None:
print(" -- Unloading LoRA...")
self.lora = None
else:
self.lora = ExLlamaV2Lora.from_directory(self.model, lora_path)
print(" -- Loading LoRA...")


def exl2(
model_path: str,
device: Optional[str] = None,
model_kwargs: dict = {},
device: str,
max_seq_len: Optional[int] = None,
scale_pos_emb: Optional[float] = None,
scale_alpha_value: Optional[float] = None,
no_flash_attn: Optional[bool] = None,
num_experts_per_token: Optional[int] = None,
cache_8bit: bool = False,
cache_q4: bool = False,
tokenizer_kwargs: dict = {},
):
gpu_split: Optional[str] = None,
low_mem: Optional[bool] = None,
verbose: Optional[bool] = None,
) -> ExLlamaV2Model:
"""
Load an ExLlamaV2 model.

Args:
model_path (str): Path to the model directory.
device (str): Device to load the model on. Pass in 'cuda' for GPU or 'cpu' for CPU
max_seq_len (Optional[int], optional): Maximum sequence length. Defaults to None.
scale_pos_emb (Optional[float], optional): Scale factor for positional embeddings. Defaults to None.
scale_alpha_value (Optional[float], optional): Scale alpha value. Defaults to None.
no_flash_attn (Optional[bool], optional): Disable flash attention. Defaults to None.
num_experts_per_token (Optional[int], optional): Number of experts per token. Defaults to None.
cache_8bit (bool, optional): Use 8-bit cache. Defaults to False.
cache_q4 (bool, optional): Use Q4 cache. Defaults to False.
tokenizer_kwargs (dict, optional): Additional keyword arguments for the tokenizer. Defaults to {}.
gpu_split (str): \"auto\", or VRAM allocation per GPU in GB. Auto will use exllama's autosplit feature
low_mem (bool, optional): Enable VRAM optimizations, potentially trading off speed
verbose (bool, optional): Enable if you want debugging statements

Returns:
ExLlamaV2Model: Loaded ExLlamaV2 model.

Raises:
ImportError: If the `exllamav2` library is not installed.
"""

try:
from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Cache,
ExLlamaV2Cache_8bit,
ExLlamaV2Cache_Q4,
ExLlamaV2Config,
)
from transformers import AutoTokenizer
except ImportError:
raise ImportError(
"The `exllamav2` library needs to be installed in order to use `exllamav2` models."
)

# Load tokenizer
if not verbose:
print(" -- Loading tokenizer...")
tokenizer_kwargs.setdefault("padding_side", "left")
tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_kwargs)
# tokenizer = TransformerTokenizer(model_path, **tokenizer_kwargs)

# Check fasttensors for config
if os.name != "nt":
use_fasttensors = True
else:
use_fasttensors = False

# Create config
config = ExLlamaV2Config()
config.model_dir = model_path
config.fasttensors = use_fasttensors
rlouf marked this conversation as resolved.
Show resolved Hide resolved
config.prepare()

config.max_seq_len = model_kwargs.pop("max_seq_len", config.max_seq_len)
config.scale_pos_emb = model_kwargs.pop("scale_pos_emb", config.scale_pos_emb)
config.scale_alpha_value = model_kwargs.pop(
"scale_alpha_value", config.scale_alpha_value
)
config.no_flash_attn = model_kwargs.pop("no_flash_attn", config.no_flash_attn)
config.num_experts_per_token = int(
model_kwargs.pop("num_experts_per_token", config.num_experts_per_token)
)

# Set config options
if max_seq_len is not None:
rlouf marked this conversation as resolved.
Show resolved Hide resolved
config.max_seq_len = max_seq_len
if scale_pos_emb is not None:
config.scale_pos_emb = scale_pos_emb
if scale_alpha_value is not None:
config.scale_alpha_value = scale_alpha_value
if no_flash_attn is not None:
config.no_flash_attn = no_flash_attn
if num_experts_per_token is not None:
config.num_experts_per_token = num_experts_per_token
if low_mem:
config.set_low_mem()

# Prepare the model from the config
model = ExLlamaV2(config)

split = None
if "gpu_split" in model_kwargs.keys():
split = [float(alloc) for alloc in model_kwargs["gpu_split"].split(",")]

model.load(split)
# Create cache
if cache_8bit:
rlouf marked this conversation as resolved.
Show resolved Hide resolved
cache = ExLlamaV2Cache_8bit(model, lazy=not model.loaded)
elif cache_q4:
cache = ExLlamaV2Cache_Q4(model, lazy=not model.loaded)
else:
cache = ExLlamaV2Cache(model, lazy=not model.loaded)

tokenizer_kwargs.setdefault("padding_side", "left")
tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_kwargs)

cache = ExLlamaV2Cache(model)
# Load the model
split = None
if gpu_split and gpu_split != "auto":
split = [float(alloc) for alloc in gpu_split.split(",")]
if not verbose:
print(" -- Loading model...")
model.load(split)

# Autoload if no GPU split was provided
if not model.loaded:
print(" -- Loading model...")
model.load_autosplit(cache)

return ExLlamaV2Model(model, tokenizer, device, cache)
Loading