Add ExllamaV2 Inference Framework Support. (#2455)
1 parent 9d27d68 · commit 466da28

Showing 8 changed files with 249 additions and 3 deletions.
@@ -0,0 +1,61 @@
# ExllamaV2 GPTQ Inference Framework

Integrate the [ExllamaV2](https://github.com/turboderp/exllamav2) customized kernel into FastChat to provide **faster** GPTQ inference speed.

**Note: ExllamaV2 does not yet support the embedding REST API.**

## Install ExllamaV2

Set up the environment (please refer to [this link](https://github.com/turboderp/exllamav2#how-to) for more details):

```bash
git clone https://github.com/turboderp/exllamav2
cd exllamav2
pip install -e .
```
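
As an optional sanity check after installation (not part of this commit), the classes used by the FastChat integration shown below should be importable:

```python
# Optional sanity check (illustrative only, not part of this commit): verify that
# the ExLlamaV2 classes used by the FastChat integration below can be imported.
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler

print("exllamav2 is installed and importable")
```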
Chat with the CLI:

```bash
python3 -m fastchat.serve.cli \
    --model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \
    --enable-exllama
```

Start the model worker:

```bash
# Download the quantized model from Hugging Face.
# Make sure you have git-lfs installed (https://git-lfs.com).
git lfs install
git clone https://huggingface.co/TheBloke/vicuna-7B-1.1-GPTQ-4bit-128g models/vicuna-7B-1.1-GPTQ-4bit-128g

# Load the model with the default configuration (max sequence length 4096, no GPU split).
python3 -m fastchat.serve.model_worker \
    --model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \
    --enable-exllama

# Load the model with a max sequence length of 2048, allocating 18 GB to CUDA:0 and 24 GB to CUDA:1.
python3 -m fastchat.serve.model_worker \
    --model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \
    --enable-exllama \
    --exllama-max-seq-len 2048 \
    --exllama-gpu-split 18,24
```
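
With the worker running behind FastChat's controller and OpenAI-compatible API server (`python3 -m fastchat.serve.controller` and `python3 -m fastchat.serve.openai_api_server`, which this doc does not cover), a minimal sketch for querying the quantized model could look like the following; the host, port, and model name are assumptions derived from the commands above:

```python
# Minimal sketch, assuming a FastChat controller and openai_api_server are already
# running on localhost:8000 and the worker started above has registered with them.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "vicuna-7B-1.1-GPTQ-4bit-128g",  # assumed name, derived from the model path
        "messages": [{"role": "user", "content": "Hello, who are you?"}],
        "max_tokens": 64,
    },
)
print(resp.json()["choices"][0]["message"]["content"])
```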
## Performance

Reference: https://github.com/turboderp/exllamav2#performance

| Model     | Mode         | Size | grpsz | act | V1: 3090Ti | V1: 4090 | V2: 3090Ti | V2: 4090    |
|-----------|--------------|------|-------|-----|------------|----------|------------|-------------|
| Llama     | GPTQ         | 7B   | 128   | no  | 143 t/s    | 173 t/s  | 175 t/s    | **195** t/s |
| Llama     | GPTQ         | 13B  | 128   | no  | 84 t/s     | 102 t/s  | 105 t/s    | **110** t/s |
| Llama     | GPTQ         | 33B  | 128   | yes | 37 t/s     | 45 t/s   | 45 t/s     | **48** t/s  |
| OpenLlama | GPTQ         | 3B   | 128   | yes | 194 t/s    | 226 t/s  | 295 t/s    | **321** t/s |
| CodeLlama | EXL2 4.0 bpw | 34B  | -     | -   | -          | -        | 42 t/s     | **48** t/s  |
| Llama2    | EXL2 3.0 bpw | 7B   | -     | -   | -          | -        | 195 t/s    | **224** t/s |
| Llama2    | EXL2 4.0 bpw | 7B   | -     | -   | -          | -        | 164 t/s    | **197** t/s |
| Llama2    | EXL2 5.0 bpw | 7B   | -     | -   | -          | -        | 144 t/s    | **160** t/s |
| Llama2    | EXL2 2.5 bpw | 70B  | -     | -   | -          | -        | 30 t/s     | **35** t/s  |
| TinyLlama | EXL2 3.0 bpw | 1.1B | -     | -   | -          | -        | 536 t/s    | **635** t/s |
| TinyLlama | EXL2 4.0 bpw | 1.1B | -     | -   | -          | -        | 509 t/s    | **590** t/s |
@@ -0,0 +1,76 @@
```python
import sys
import torch
import gc
from typing import Dict


def generate_stream_exllama(
    model,
    tokenizer,
    params: Dict,
    device: str,
    context_len: int,
    stream_interval: int = 2,
    judge_sent_end: bool = False,
):
    try:
        from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
    except ImportError as e:
        print(f"Error: Failed to load Exllamav2. {e}")
        sys.exit(-1)

    prompt = params["prompt"]

    generator = ExLlamaV2StreamingGenerator(model.model, model.cache, tokenizer)
    settings = ExLlamaV2Sampler.Settings()

    # Map FastChat sampling parameters onto the ExLlamaV2 sampler settings.
    settings.temperature = float(params.get("temperature", 0.85))
    settings.top_k = int(params.get("top_k", 50))
    settings.top_p = float(params.get("top_p", 0.8))
    settings.token_repetition_penalty = float(params.get("repetition_penalty", 1.15))
    settings.disallow_tokens(generator.tokenizer, [generator.tokenizer.eos_token_id])

    max_new_tokens = int(params.get("max_new_tokens", 256))

    generator.set_stop_conditions(params.get("stop_token_ids", None) or [])
    echo = bool(params.get("echo", True))

    input_ids = generator.tokenizer.encode(prompt)
    prompt_tokens = input_ids.shape[-1]
    generator.begin_stream(input_ids, settings)

    generated_tokens = 0
    # With echo enabled, the prompt is included in the streamed text.
    if echo:
        output = prompt
    else:
        output = ""
    while True:
        chunk, eos, _ = generator.stream()
        output += chunk
        generated_tokens += 1
        if generated_tokens == max_new_tokens:
            finish_reason = "length"
            break
        elif eos:
            # A stop condition was hit before reaching the token limit.
            finish_reason = "stop"
            break
        # Emit a partial result for each streamed chunk.
        yield {
            "text": output,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": generated_tokens,
                "total_tokens": prompt_tokens + generated_tokens,
            },
            "finish_reason": None,
        }
    # Final result, including the reason generation stopped.
    yield {
        "text": output,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": generated_tokens,
            "total_tokens": prompt_tokens + generated_tokens,
        },
        "finish_reason": finish_reason,
    }
    gc.collect()
```
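
For reference, the keys this function reads from `params` (with the defaults used in the code above when a key is missing) can be collected into a request dict like the following; this is an illustrative sketch, not part of the commit:

```python
# Illustrative only: the request fields generate_stream_exllama() reads,
# shown with the default values the function falls back to.
params = {
    "prompt": "USER: Hello, who are you? ASSISTANT:",
    "temperature": 0.85,
    "top_k": 50,
    "top_p": 0.8,
    "repetition_penalty": 1.15,
    "max_new_tokens": 256,
    "stop_token_ids": [],  # forwarded to generator.set_stop_conditions()
    "echo": True,          # include the prompt in the streamed "text"
}
```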
@@ -0,0 +1,46 @@
```python
from dataclasses import dataclass, field
import sys


@dataclass
class ExllamaConfig:
    # Maximum sequence length for the loaded model.
    max_seq_len: int
    # Comma-separated per-GPU VRAM allocation in GB, e.g. "18,24"; None means no split.
    gpu_split: str = None


class ExllamaModel:
    """Thin wrapper bundling an ExLlamaV2 model with its cache."""

    def __init__(self, exllama_model, exllama_cache):
        self.model = exllama_model
        self.cache = exllama_cache
        self.config = self.model.config


def load_exllama_model(model_path, exllama_config: ExllamaConfig):
    try:
        from exllamav2 import (
            ExLlamaV2Config,
            ExLlamaV2Tokenizer,
            ExLlamaV2,
            ExLlamaV2Cache,
        )
    except ImportError as e:
        print(f"Error: Failed to load Exllamav2. {e}")
        sys.exit(-1)

    # Read the model directory and apply the requested max sequence length.
    exllamav2_config = ExLlamaV2Config()
    exllamav2_config.model_dir = model_path
    exllamav2_config.prepare()
    exllamav2_config.max_seq_len = exllama_config.max_seq_len

    exllama_model = ExLlamaV2(exllamav2_config)
    tokenizer = ExLlamaV2Tokenizer(exllamav2_config)

    # Parse the optional "a,b,..." GPU split string into per-device GB allocations.
    split = None
    if exllama_config.gpu_split:
        split = [float(alloc) for alloc in exllama_config.gpu_split.split(",")]
    exllama_model.load(split)

    exllama_cache = ExLlamaV2Cache(exllama_model)
    model = ExllamaModel(exllama_model=exllama_model, exllama_cache=exllama_cache)

    return model, tokenizer
```
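
Putting the two new modules together, a minimal end-to-end sketch could look like the following; it is not part of the commit, and the model path and prompt are placeholders:

```python
# Minimal end-to-end sketch (not part of the commit): load a GPTQ model through the
# new ExllamaV2 path, then stream a completion with generate_stream_exllama().
config = ExllamaConfig(max_seq_len=2048, gpu_split=None)
model, tokenizer = load_exllama_model("models/vicuna-7B-1.1-GPTQ-4bit-128g", config)

params = {"prompt": "USER: Tell me a joke. ASSISTANT:", "max_new_tokens": 64}
for out in generate_stream_exllama(model, tokenizer, params, device="cuda", context_len=2048):
    pass  # each `out["text"]` holds the full text generated so far
print(out["text"], out["finish_reason"])
```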