From 577b6c79ac7ff5e02e501271934e47cca34a7f73 Mon Sep 17 00:00:00 2001 From: leonxia1018 Date: Sat, 7 Oct 2023 18:52:15 +0800 Subject: [PATCH] Add Exllama2 inference framework support. --- docs/exllamaV2.md | 61 ++++++++++++++++++++++ fastchat/model/model_adapter.py | 27 ++++++++++ fastchat/model/model_exllama.py | 76 ++++++++++++++++++++++++++++ fastchat/modules/exllama.py | 46 +++++++++++++++++ fastchat/serve/cli.py | 10 +++- fastchat/serve/inference.py | 3 ++ fastchat/serve/model_worker.py | 20 +++++++- fastchat/serve/multi_model_worker.py | 9 ++++ 8 files changed, 249 insertions(+), 3 deletions(-) create mode 100644 docs/exllamaV2.md create mode 100644 fastchat/model/model_exllama.py create mode 100644 fastchat/modules/exllama.py diff --git a/docs/exllamaV2.md b/docs/exllamaV2.md new file mode 100644 index 000000000..b4df2dee2 --- /dev/null +++ b/docs/exllamaV2.md @@ -0,0 +1,61 @@ +# ExllamaV2 GPTQ Inference Franework + +Integrated [ExllamaV2] (https://github.com/turboderp/exllamav2) customized kernel into Fastchat to provide **Faster** GPTQ inference speed. + +**Note: Exllama not yet support embedding REST API.** + +## Install ExllamaV2 + +Setup environment (please refer to [this link](https://github.com/turboderp/exllamav2#how-to) for more details): + +```bash +git clone https://github.com/turboderp/exllamav2 +cd exllamav2 +pip install -e . +``` + +Chat with the CLI: +```bash +python3 -m fastchat.serve.cli \ + --model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \ + --enable-exllama +``` + +Start model worker: +```bash +# Download quantized model from huggingface +# Make sure you have git-lfs installed (https://git-lfs.com) +git lfs install +git clone https://huggingface.co/TheBloke/vicuna-7B-1.1-GPTQ-4bit-128g models/vicuna-7B-1.1-GPTQ-4bit-128g + +# Load model with default configuration (max sequence length 4096, no GPU split setting). +python3 -m fastchat.serve.model_worker \ + --model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \ + --enable-exllama + +#Load model with max sequence length 2048, allocate 18 GB to CUDA:0 and 24 GB to CUDA:1. +python3 -m fastchat.serve.model_worker \ + --model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \ + --enable-exllama \ + --exllama-max-seq-len 2048 \ + --exllama-gpu-split 18,24 +``` + +## Performance + +Reference: https://github.com/turboderp/exllamav2#performance + + +| Model | Mode | Size | grpsz | act | V1: 3090Ti | V1: 4090 | V2: 3090Ti | V2: 4090 | +|------------|--------------|-------|-------|-----|------------|----------|------------|-------------| +| Llama | GPTQ | 7B | 128 | no | 143 t/s | 173 t/s | 175 t/s | **195** t/s | +| Llama | GPTQ | 13B | 128 | no | 84 t/s | 102 t/s | 105 t/s | **110** t/s | +| Llama | GPTQ | 33B | 128 | yes | 37 t/s | 45 t/s | 45 t/s | **48** t/s | +| OpenLlama | GPTQ | 3B | 128 | yes | 194 t/s | 226 t/s | 295 t/s | **321** t/s | +| CodeLlama | EXL2 4.0 bpw | 34B | - | - | - | - | 42 t/s | **48** t/s | +| Llama2 | EXL2 3.0 bpw | 7B | - | - | - | - | 195 t/s | **224** t/s | +| Llama2 | EXL2 4.0 bpw | 7B | - | - | - | - | 164 t/s | **197** t/s | +| Llama2 | EXL2 5.0 bpw | 7B | - | - | - | - | 144 t/s | **160** t/s | +| Llama2 | EXL2 2.5 bpw | 70B | - | - | - | - | 30 t/s | **35** t/s | +| TinyLlama | EXL2 3.0 bpw | 1.1B | - | - | - | - | 536 t/s | **635** t/s | +| TinyLlama | EXL2 4.0 bpw | 1.1B | - | - | - | - | 509 t/s | **590** t/s | \ No newline at end of file diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py index d2ac56f8d..76bd7b245 100644 --- a/fastchat/model/model_adapter.py +++ b/fastchat/model/model_adapter.py @@ -29,12 +29,14 @@ from fastchat.constants import CPU_ISA from fastchat.modules.gptq import GptqConfig, load_gptq_quantized from fastchat.modules.awq import AWQConfig, load_awq_quantized +from fastchat.modules.exllama import ExllamaConfig, load_exllama_model from fastchat.conversation import Conversation, get_conv_template from fastchat.model.compression import load_compress_model from fastchat.model.llama_condense_monkey_patch import replace_llama_with_condense from fastchat.model.model_chatglm import generate_stream_chatglm from fastchat.model.model_codet5p import generate_stream_codet5p from fastchat.model.model_falcon import generate_stream_falcon +from fastchat.model.model_exllama import generate_stream_exllama from fastchat.model.monkey_patch_non_inplace import ( replace_llama_attn_with_non_inplace_operations, ) @@ -155,6 +157,7 @@ def load_model( cpu_offloading: bool = False, gptq_config: Optional[GptqConfig] = None, awq_config: Optional[AWQConfig] = None, + exllama_config: Optional[ExllamaConfig] = None, revision: str = "main", debug: bool = False, ): @@ -279,6 +282,9 @@ def load_model( else: model.to(device) return model, tokenizer + elif exllama_config: + model, tokenizer = load_exllama_model(model_path, exllama_config) + return model, tokenizer kwargs["revision"] = revision if dtype is not None: # Overwrite dtype if it is provided in the arguments. @@ -325,6 +331,7 @@ def get_generate_stream_function(model: torch.nn.Module, model_path: str): is_falcon = "rwforcausallm" in model_type is_codet5p = "codet5p" in model_type is_peft = "peft" in model_type + is_exllama = "exllama" in model_type if is_chatglm: return generate_stream_chatglm @@ -332,6 +339,9 @@ def get_generate_stream_function(model: torch.nn.Module, model_path: str): return generate_stream_falcon elif is_codet5p: return generate_stream_codet5p + elif is_exllama: + return generate_stream_exllama + elif peft_share_base_weights and is_peft: # Return a curried stream function that loads the right adapter # according to the model_name available in this context. This ensures @@ -453,6 +463,23 @@ def add_model_args(parser): default=-1, help="Used for AWQ. Groupsize to use for AWQ quantization; default uses full row.", ) + parser.add_argument( + "--enable-exllama", + action="store_true", + help="Used for exllamabv2. Enable exllamaV2 inference framework.", + ) + parser.add_argument( + "--exllama-max-seq-len", + type=int, + default=4096, + help="Used for exllamabv2. Max sequence length to use for exllamav2 framework; default 4096 sequence length.", + ) + parser.add_argument( + "--exllama-gpu-split", + type=str, + default=None, + help="Used for exllamabv2. Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7", + ) def remove_parent_directory_name(model_path): diff --git a/fastchat/model/model_exllama.py b/fastchat/model/model_exllama.py new file mode 100644 index 000000000..d0cba38b4 --- /dev/null +++ b/fastchat/model/model_exllama.py @@ -0,0 +1,76 @@ +import sys +import torch +import gc +from typing import Dict + + +def generate_stream_exllama( + model, + tokenizer, + params: Dict, + device: str, + context_len: int, + stream_interval: int = 2, + judge_sent_end: bool = False, +): + try: + from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler + except ImportError as e: + print(f"Error: Failed to load Exllamav2. {e}") + sys.exit(-1) + + prompt = params["prompt"] + + generator = ExLlamaV2StreamingGenerator(model.model, model.cache, tokenizer) + settings = ExLlamaV2Sampler.Settings() + + settings.temperature = float(params.get("temperature", 0.85)) + settings.top_k = int(params.get("top_k", 50)) + settings.top_p = float(params.get("top_p", 0.8)) + settings.token_repetition_penalty = float(params.get("repetition_penalty", 1.15)) + settings.disallow_tokens(generator.tokenizer, [generator.tokenizer.eos_token_id]) + + max_new_tokens = int(params.get("max_new_tokens", 256)) + + generator.set_stop_conditions(params.get("stop_token_ids", None) or []) + echo = bool(params.get("echo", True)) + + input_ids = generator.tokenizer.encode(prompt) + prompt_tokens = input_ids.shape[-1] + generator.begin_stream(input_ids, settings) + + generated_tokens = 0 + if echo: + output = prompt + else: + output = "" + while True: + chunk, eos, _ = generator.stream() + output += chunk + generated_tokens += 1 + if generated_tokens == max_new_tokens: + finish_reason = "length" + break + elif eos: + finish_reason = "length" + break + yield { + "text": output, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": generated_tokens, + "total_tokens": prompt_tokens + generated_tokens, + }, + "finish_reason": None, + } + + yield { + "text": output, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": generated_tokens, + "total_tokens": prompt_tokens + generated_tokens, + }, + "finish_reason": finish_reason, + } + gc.collect() diff --git a/fastchat/modules/exllama.py b/fastchat/modules/exllama.py new file mode 100644 index 000000000..5bddaa91d --- /dev/null +++ b/fastchat/modules/exllama.py @@ -0,0 +1,46 @@ +from dataclasses import dataclass, field +import sys + + +@dataclass +class ExllamaConfig: + max_seq_len: int + gpu_split: str = None + + +class ExllamaModel: + def __init__(self, exllama_model, exllama_cache): + self.model = exllama_model + self.cache = exllama_cache + self.config = self.model.config + + +def load_exllama_model(model_path, exllama_config: ExllamaConfig): + try: + from exllamav2 import ( + ExLlamaV2Config, + ExLlamaV2Tokenizer, + ExLlamaV2, + ExLlamaV2Cache, + ) + except ImportError as e: + print(f"Error: Failed to load Exllamav2. {e}") + sys.exit(-1) + + exllamav2_config = ExLlamaV2Config() + exllamav2_config.model_dir = model_path + exllamav2_config.prepare() + exllamav2_config.max_seq_len = exllama_config.max_seq_len + + exllama_model = ExLlamaV2(exllamav2_config) + tokenizer = ExLlamaV2Tokenizer(exllamav2_config) + + split = None + if exllama_config.gpu_split: + split = [float(alloc) for alloc in exllama_config.gpu_split.split(",")] + exllama_model.load(split) + + exllama_cache = ExLlamaV2Cache(exllama_model) + model = ExllamaModel(exllama_model=exllama_model, exllama_cache=exllama_cache) + + return model, tokenizer diff --git a/fastchat/serve/cli.py b/fastchat/serve/cli.py index de52a44bd..716869db9 100644 --- a/fastchat/serve/cli.py +++ b/fastchat/serve/cli.py @@ -31,6 +31,7 @@ from fastchat.model.model_adapter import add_model_args from fastchat.modules.gptq import GptqConfig from fastchat.modules.awq import AWQConfig +from fastchat.modules.exllama import ExllamaConfig from fastchat.serve.inference import ChatIO, chat_loop from fastchat.utils import str_to_torch_dtype @@ -195,7 +196,13 @@ def main(args): ) os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus os.environ["XPU_VISIBLE_DEVICES"] = args.gpus - + if args.enable_exllama: + exllama_config = ExllamaConfig( + max_seq_len=args.exllama_max_seq_len, + gpu_split=args.exllama_gpu_split, + ) + else: + exllama_config = None if args.style == "simple": chatio = SimpleChatIO(args.multiline) elif args.style == "rich": @@ -230,6 +237,7 @@ def main(args): wbits=args.awq_wbits, groupsize=args.awq_groupsize, ), + exllama_config=exllama_config, revision=args.revision, judge_sent_end=args.judge_sent_end, debug=args.debug, diff --git a/fastchat/serve/inference.py b/fastchat/serve/inference.py index 169f086b9..b398adda1 100644 --- a/fastchat/serve/inference.py +++ b/fastchat/serve/inference.py @@ -37,6 +37,7 @@ ) from fastchat.modules.gptq import GptqConfig from fastchat.modules.awq import AWQConfig +from fastchat.modules.exllama import ExllamaConfig from fastchat.utils import is_partial_stop, is_sentence_complete, get_context_length @@ -302,6 +303,7 @@ def chat_loop( chatio: ChatIO, gptq_config: Optional[GptqConfig] = None, awq_config: Optional[AWQConfig] = None, + exllama_config: Optional[ExllamaConfig] = None, revision: str = "main", judge_sent_end: bool = True, debug: bool = True, @@ -318,6 +320,7 @@ def chat_loop( cpu_offloading=cpu_offloading, gptq_config=gptq_config, awq_config=awq_config, + exllama_config=exllama_config, revision=revision, debug=debug, ) diff --git a/fastchat/serve/model_worker.py b/fastchat/serve/model_worker.py index 54d51cfd0..1cf999a3d 100644 --- a/fastchat/serve/model_worker.py +++ b/fastchat/serve/model_worker.py @@ -53,6 +53,8 @@ get_context_length, str_to_torch_dtype, ) +from fastchat.modules.exllama import ExllamaConfig +from fastchat.utils import build_logger, pretty_print_semaphore, get_context_length worker_id = str(uuid.uuid4())[:8] @@ -170,8 +172,12 @@ def get_status(self): def count_token(self, params): prompt = params["prompt"] - input_ids = self.tokenizer(prompt).input_ids - input_echo_len = len(input_ids) + + try: + input_ids = self.tokenizer(prompt).input_ids + input_echo_len = len(input_ids) + except TypeError: + input_echo_len = self.tokenizer.num_tokens(prompt) ret = { "count": input_echo_len, @@ -201,6 +207,7 @@ def __init__( cpu_offloading: bool = False, gptq_config: Optional[GptqConfig] = None, awq_config: Optional[AWQConfig] = None, + exllama_config: Optional[ExllamaConfig] = None, stream_interval: int = 2, conv_template: Optional[str] = None, embed_in_truncate: bool = False, @@ -228,6 +235,7 @@ def __init__( cpu_offloading=cpu_offloading, gptq_config=gptq_config, awq_config=awq_config, + exllama_config=exllama_config, ) self.device = device if self.tokenizer.pad_token == None: @@ -514,6 +522,13 @@ def create_model_worker(): wbits=args.awq_wbits, groupsize=args.awq_groupsize, ) + if args.enable_exllama: + exllama_config = ExllamaConfig( + max_seq_len=args.exllama_max_seq_len, + gpu_split=args.exllama_gpu_split, + ) + else: + exllama_config = None worker = ModelWorker( args.controller_address, @@ -531,6 +546,7 @@ def create_model_worker(): cpu_offloading=args.cpu_offloading, gptq_config=gptq_config, awq_config=awq_config, + exllama_config=exllama_config, stream_interval=args.stream_interval, conv_template=args.conv_template, embed_in_truncate=args.embed_in_truncate, diff --git a/fastchat/serve/multi_model_worker.py b/fastchat/serve/multi_model_worker.py index 13872bbdd..823378687 100644 --- a/fastchat/serve/multi_model_worker.py +++ b/fastchat/serve/multi_model_worker.py @@ -54,6 +54,7 @@ from fastchat.model.model_falcon import generate_stream_falcon from fastchat.model.model_codet5p import generate_stream_codet5p from fastchat.modules.gptq import GptqConfig +from fastchat.modules.exllama import ExllamaConfig from fastchat.serve.inference import generate_stream from fastchat.serve.model_worker import ModelWorker, worker_id, logger from fastchat.utils import build_logger, pretty_print_semaphore, get_context_length @@ -204,6 +205,13 @@ def create_multi_model_worker(): groupsize=args.gptq_groupsize, act_order=args.gptq_act_order, ) + if args.enable_exllama: + exllama_config = ExllamaConfig( + max_seq_len=args.exllama_max_seq_len, + gpu_split=args.exllama_gpu_split, + ) + else: + exllama_config = None if args.model_names is None: args.model_names = [[x.split("/")[-1]] for x in args.model_path] @@ -232,6 +240,7 @@ def create_multi_model_worker(): load_8bit=args.load_8bit, cpu_offloading=args.cpu_offloading, gptq_config=gptq_config, + exllama_config=exllama_config, stream_interval=args.stream_interval, conv_template=conv_template, )