Add ExllamaV2 Inference Framework Support. (#2455)
leonxia1018 authored Oct 9, 2023
1 parent 9d27d68 commit 466da28
Showing 8 changed files with 249 additions and 3 deletions.
61 changes: 61 additions & 0 deletions docs/exllamaV2.md
@@ -0,0 +1,61 @@
# ExllamaV2 GPTQ Inference Framework

Integrates the [ExllamaV2](https://github.com/turboderp/exllamav2) customized kernels into FastChat to provide **faster** GPTQ inference.

**Note: ExllamaV2 does not yet support the embedding REST API.**

## Install ExllamaV2

Set up the environment (please refer to [this link](https://github.com/turboderp/exllamav2#how-to) for more details):

```bash
git clone https://github.com/turboderp/exllamav2
cd exllamav2
pip install -e .
```

Chat with the CLI:
```bash
python3 -m fastchat.serve.cli \
--model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \
--enable-exllama
```

Start model worker:
```bash
# Download the quantized model from Hugging Face
# Make sure you have git-lfs installed (https://git-lfs.com)
git lfs install
git clone https://huggingface.co/TheBloke/vicuna-7B-1.1-GPTQ-4bit-128g models/vicuna-7B-1.1-GPTQ-4bit-128g

# Load model with default configuration (max sequence length 4096, no GPU split setting).
python3 -m fastchat.serve.model_worker \
--model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \
--enable-exllama

# Load model with max sequence length 2048, allocating 18 GB to CUDA:0 and 24 GB to CUDA:1.
python3 -m fastchat.serve.model_worker \
--model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \
--enable-exllama \
--exllama-max-seq-len 2048 \
--exllama-gpu-split 18,24
```
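
To chat with the worker over FastChat's OpenAI-compatible REST API, the usual FastChat serving stack applies. The commands below are a hedged sketch; the host and port are example values, not part of this change:
```bash
# In separate terminals: start the controller, then the worker command above,
# then the OpenAI-compatible API server.
python3 -m fastchat.serve.controller
python3 -m fastchat.serve.openai_api_server --host localhost --port 8000
```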

## Performance

Reference: https://github.com/turboderp/exllamav2#performance


| Model | Mode | Size | grpsz | act | V1: 3090Ti | V1: 4090 | V2: 3090Ti | V2: 4090 |
|------------|--------------|-------|-------|-----|------------|----------|------------|-------------|
| Llama | GPTQ | 7B | 128 | no | 143 t/s | 173 t/s | 175 t/s | **195** t/s |
| Llama | GPTQ | 13B | 128 | no | 84 t/s | 102 t/s | 105 t/s | **110** t/s |
| Llama | GPTQ | 33B | 128 | yes | 37 t/s | 45 t/s | 45 t/s | **48** t/s |
| OpenLlama | GPTQ | 3B | 128 | yes | 194 t/s | 226 t/s | 295 t/s | **321** t/s |
| CodeLlama | EXL2 4.0 bpw | 34B | - | - | - | - | 42 t/s | **48** t/s |
| Llama2 | EXL2 3.0 bpw | 7B | - | - | - | - | 195 t/s | **224** t/s |
| Llama2 | EXL2 4.0 bpw | 7B | - | - | - | - | 164 t/s | **197** t/s |
| Llama2 | EXL2 5.0 bpw | 7B | - | - | - | - | 144 t/s | **160** t/s |
| Llama2 | EXL2 2.5 bpw | 70B | - | - | - | - | 30 t/s | **35** t/s |
| TinyLlama | EXL2 3.0 bpw | 1.1B | - | - | - | - | 536 t/s | **635** t/s |
| TinyLlama | EXL2 4.0 bpw | 1.1B | - | - | - | - | 509 t/s | **590** t/s |
27 changes: 27 additions & 0 deletions fastchat/model/model_adapter.py
@@ -29,12 +29,14 @@
from fastchat.constants import CPU_ISA
from fastchat.modules.gptq import GptqConfig, load_gptq_quantized
from fastchat.modules.awq import AWQConfig, load_awq_quantized
from fastchat.modules.exllama import ExllamaConfig, load_exllama_model
from fastchat.conversation import Conversation, get_conv_template
from fastchat.model.compression import load_compress_model
from fastchat.model.llama_condense_monkey_patch import replace_llama_with_condense
from fastchat.model.model_chatglm import generate_stream_chatglm
from fastchat.model.model_codet5p import generate_stream_codet5p
from fastchat.model.model_falcon import generate_stream_falcon
from fastchat.model.model_exllama import generate_stream_exllama
from fastchat.model.monkey_patch_non_inplace import (
replace_llama_attn_with_non_inplace_operations,
)
@@ -155,6 +157,7 @@ def load_model(
cpu_offloading: bool = False,
gptq_config: Optional[GptqConfig] = None,
awq_config: Optional[AWQConfig] = None,
exllama_config: Optional[ExllamaConfig] = None,
revision: str = "main",
debug: bool = False,
):
@@ -279,6 +282,9 @@ def load_model(
else:
model.to(device)
return model, tokenizer
elif exllama_config:
model, tokenizer = load_exllama_model(model_path, exllama_config)
return model, tokenizer
kwargs["revision"] = revision

if dtype is not None: # Overwrite dtype if it is provided in the arguments.
@@ -325,13 +331,17 @@ def get_generate_stream_function(model: torch.nn.Module, model_path: str):
is_falcon = "rwforcausallm" in model_type
is_codet5p = "codet5p" in model_type
is_peft = "peft" in model_type
is_exllama = "exllama" in model_type

if is_chatglm:
return generate_stream_chatglm
elif is_falcon:
return generate_stream_falcon
elif is_codet5p:
return generate_stream_codet5p
elif is_exllama:
return generate_stream_exllama

elif peft_share_base_weights and is_peft:
# Return a curried stream function that loads the right adapter
# according to the model_name available in this context. This ensures
@@ -453,6 +463,23 @@ def add_model_args(parser):
default=-1,
help="Used for AWQ. Groupsize to use for AWQ quantization; default uses full row.",
)
parser.add_argument(
"--enable-exllama",
action="store_true",
help="Used for exllamabv2. Enable exllamaV2 inference framework.",
)
parser.add_argument(
"--exllama-max-seq-len",
type=int,
default=4096,
help="Used for exllamabv2. Max sequence length to use for exllamav2 framework; default 4096 sequence length.",
)
parser.add_argument(
"--exllama-gpu-split",
type=str,
default=None,
help="Used for exllamabv2. Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7",
)


def remove_parent_directory_name(model_path):
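
Putting the adapter changes together, here is a minimal sketch of programmatic use. It assumes the signatures added in this commit; the checkpoint path and GPU split values are placeholders:
```python
from fastchat.model.model_adapter import load_model, get_generate_stream_function
from fastchat.modules.exllama import ExllamaConfig

# Hypothetical local path to a GPTQ checkpoint; adjust to your setup.
model_path = "models/vicuna-7B-1.1-GPTQ-4bit-128g"

exllama_config = ExllamaConfig(max_seq_len=4096, gpu_split="18,24")

# With exllama_config set, load_model() delegates to load_exllama_model()
# and returns the ExllamaModel wrapper plus the ExLlamaV2 tokenizer.
model, tokenizer = load_model(
    model_path,
    device="cuda",
    exllama_config=exllama_config,
)

# The dispatcher checks the model's type string, so the ExllamaModel wrapper
# resolves to generate_stream_exllama.
generate_stream_func = get_generate_stream_function(model, model_path)
```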
76 changes: 76 additions & 0 deletions fastchat/model/model_exllama.py
@@ -0,0 +1,76 @@
import sys
import torch
import gc
from typing import Dict


def generate_stream_exllama(
model,
tokenizer,
params: Dict,
device: str,
context_len: int,
stream_interval: int = 2,
judge_sent_end: bool = False,
):
try:
from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
except ImportError as e:
print(f"Error: Failed to load Exllamav2. {e}")
sys.exit(-1)

prompt = params["prompt"]

generator = ExLlamaV2StreamingGenerator(model.model, model.cache, tokenizer)
settings = ExLlamaV2Sampler.Settings()

settings.temperature = float(params.get("temperature", 0.85))
settings.top_k = int(params.get("top_k", 50))
settings.top_p = float(params.get("top_p", 0.8))
settings.token_repetition_penalty = float(params.get("repetition_penalty", 1.15))
settings.disallow_tokens(generator.tokenizer, [generator.tokenizer.eos_token_id])

max_new_tokens = int(params.get("max_new_tokens", 256))

generator.set_stop_conditions(params.get("stop_token_ids", None) or [])
echo = bool(params.get("echo", True))

input_ids = generator.tokenizer.encode(prompt)
prompt_tokens = input_ids.shape[-1]
generator.begin_stream(input_ids, settings)

generated_tokens = 0
if echo:
output = prompt
else:
output = ""
while True:
chunk, eos, _ = generator.stream()
output += chunk
generated_tokens += 1
if generated_tokens == max_new_tokens:
finish_reason = "length"
break
elif eos:
finish_reason = "length"
break
yield {
"text": output,
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": generated_tokens,
"total_tokens": prompt_tokens + generated_tokens,
},
"finish_reason": None,
}

yield {
"text": output,
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": generated_tokens,
"total_tokens": prompt_tokens + generated_tokens,
},
"finish_reason": finish_reason,
}
gc.collect()
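
For illustration, a hedged sketch of consuming this generator directly; the checkpoint path and prompt are placeholders, and the loader comes from fastchat/modules/exllama.py below:
```python
from fastchat.model.model_exllama import generate_stream_exllama
from fastchat.modules.exllama import ExllamaConfig, load_exllama_model

# Placeholder checkpoint path; single GPU, so no gpu_split.
config = ExllamaConfig(max_seq_len=4096, gpu_split=None)
model, tokenizer = load_exllama_model("models/vicuna-7B-1.1-GPTQ-4bit-128g", config)

params = {
    "prompt": "USER: What is GPTQ quantization? ASSISTANT:",
    "temperature": 0.7,
    "top_p": 0.9,
    "max_new_tokens": 128,
    "echo": False,
}

previous = ""
for out in generate_stream_exllama(model, tokenizer, params, device="cuda", context_len=4096):
    # Each yield carries the full generated text so far; print only the new part.
    delta = out["text"][len(previous):]
    previous = out["text"]
    print(delta, end="", flush=True)

print()
print("finish_reason:", out["finish_reason"], "usage:", out["usage"])
```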
46 changes: 46 additions & 0 deletions fastchat/modules/exllama.py
@@ -0,0 +1,46 @@
from dataclasses import dataclass, field
import sys


@dataclass
class ExllamaConfig:
max_seq_len: int
gpu_split: str = None


class ExllamaModel:
def __init__(self, exllama_model, exllama_cache):
self.model = exllama_model
self.cache = exllama_cache
self.config = self.model.config


def load_exllama_model(model_path, exllama_config: ExllamaConfig):
try:
from exllamav2 import (
ExLlamaV2Config,
ExLlamaV2Tokenizer,
ExLlamaV2,
ExLlamaV2Cache,
)
except ImportError as e:
print(f"Error: Failed to load Exllamav2. {e}")
sys.exit(-1)

exllamav2_config = ExLlamaV2Config()
exllamav2_config.model_dir = model_path
exllamav2_config.prepare()
exllamav2_config.max_seq_len = exllama_config.max_seq_len

exllama_model = ExLlamaV2(exllamav2_config)
tokenizer = ExLlamaV2Tokenizer(exllamav2_config)

split = None
if exllama_config.gpu_split:
split = [float(alloc) for alloc in exllama_config.gpu_split.split(",")]
exllama_model.load(split)

exllama_cache = ExLlamaV2Cache(exllama_model)
model = ExllamaModel(exllama_model=exllama_model, exllama_cache=exllama_cache)

return model, tokenizer
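
A minimal usage sketch for the loader on its own; the checkpoint path is a placeholder, and the split assigns roughly 18 GB and 24 GB of VRAM to the first and second GPU:
```python
from fastchat.modules.exllama import ExllamaConfig, load_exllama_model

# gpu_split is a comma-separated list of per-GPU VRAM budgets in GB, or None for no split.
config = ExllamaConfig(max_seq_len=2048, gpu_split="18,24")
model, tokenizer = load_exllama_model("models/vicuna-7B-1.1-GPTQ-4bit-128g", config)

# The ExllamaModel wrapper exposes exactly what generate_stream_exllama needs.
print(type(model.model).__name__)   # ExLlamaV2
print(type(model.cache).__name__)   # ExLlamaV2Cache
print(model.config.max_seq_len)     # 2048
```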
10 changes: 9 additions & 1 deletion fastchat/serve/cli.py
@@ -31,6 +31,7 @@
from fastchat.model.model_adapter import add_model_args
from fastchat.modules.gptq import GptqConfig
from fastchat.modules.awq import AWQConfig
from fastchat.modules.exllama import ExllamaConfig
from fastchat.serve.inference import ChatIO, chat_loop
from fastchat.utils import str_to_torch_dtype

@@ -195,7 +196,13 @@ def main(args):
)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
os.environ["XPU_VISIBLE_DEVICES"] = args.gpus

if args.enable_exllama:
exllama_config = ExllamaConfig(
max_seq_len=args.exllama_max_seq_len,
gpu_split=args.exllama_gpu_split,
)
else:
exllama_config = None
if args.style == "simple":
chatio = SimpleChatIO(args.multiline)
elif args.style == "rich":
@@ -230,6 +237,7 @@ def main(args):
wbits=args.awq_wbits,
groupsize=args.awq_groupsize,
),
exllama_config=exllama_config,
revision=args.revision,
judge_sent_end=args.judge_sent_end,
debug=args.debug,
3 changes: 3 additions & 0 deletions fastchat/serve/inference.py
@@ -37,6 +37,7 @@
)
from fastchat.modules.gptq import GptqConfig
from fastchat.modules.awq import AWQConfig
from fastchat.modules.exllama import ExllamaConfig
from fastchat.utils import is_partial_stop, is_sentence_complete, get_context_length


@@ -302,6 +303,7 @@ def chat_loop(
chatio: ChatIO,
gptq_config: Optional[GptqConfig] = None,
awq_config: Optional[AWQConfig] = None,
exllama_config: Optional[ExllamaConfig] = None,
revision: str = "main",
judge_sent_end: bool = True,
debug: bool = True,
@@ -318,6 +320,7 @@
cpu_offloading=cpu_offloading,
gptq_config=gptq_config,
awq_config=awq_config,
exllama_config=exllama_config,
revision=revision,
debug=debug,
)
20 changes: 18 additions & 2 deletions fastchat/serve/model_worker.py
@@ -53,6 +53,8 @@
get_context_length,
str_to_torch_dtype,
)
from fastchat.modules.exllama import ExllamaConfig
from fastchat.utils import build_logger, pretty_print_semaphore, get_context_length


worker_id = str(uuid.uuid4())[:8]
@@ -170,8 +172,12 @@ def get_status(self):

def count_token(self, params):
prompt = params["prompt"]
input_ids = self.tokenizer(prompt).input_ids
input_echo_len = len(input_ids)

try:
input_ids = self.tokenizer(prompt).input_ids
input_echo_len = len(input_ids)
except TypeError:
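# The ExLlamaV2 tokenizer is not callable like a Hugging Face tokenizer, so fall back to its num_tokens() helper.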
input_echo_len = self.tokenizer.num_tokens(prompt)

ret = {
"count": input_echo_len,
@@ -201,6 +207,7 @@ def __init__(
cpu_offloading: bool = False,
gptq_config: Optional[GptqConfig] = None,
awq_config: Optional[AWQConfig] = None,
exllama_config: Optional[ExllamaConfig] = None,
stream_interval: int = 2,
conv_template: Optional[str] = None,
embed_in_truncate: bool = False,
@@ -228,6 +235,7 @@ def __init__(
cpu_offloading=cpu_offloading,
gptq_config=gptq_config,
awq_config=awq_config,
exllama_config=exllama_config,
)
self.device = device
if self.tokenizer.pad_token == None:
@@ -514,6 +522,13 @@ def create_model_worker():
wbits=args.awq_wbits,
groupsize=args.awq_groupsize,
)
if args.enable_exllama:
exllama_config = ExllamaConfig(
max_seq_len=args.exllama_max_seq_len,
gpu_split=args.exllama_gpu_split,
)
else:
exllama_config = None

worker = ModelWorker(
args.controller_address,
Expand All @@ -531,6 +546,7 @@ def create_model_worker():
cpu_offloading=args.cpu_offloading,
gptq_config=gptq_config,
awq_config=awq_config,
exllama_config=exllama_config,
stream_interval=args.stream_interval,
conv_template=args.conv_template,
embed_in_truncate=args.embed_in_truncate,