From 7100cfd6cb0f846735cd4c5d55ddb6fcb86b5399 Mon Sep 17 00:00:00 2001
From: Huanghe
Date: Fri, 23 Aug 2024 01:38:35 -0500
Subject: [PATCH] Update vllm's lm-format-enforcer benchmark

---
 benchmarks/exllamav2_json.py    | 16 +++---
 benchmarks/readme.md            | 17 ++++---
 benchmarks/transformers_json.py |  7 +++
 benchmarks/vllm_json.py         | 90 +++++++++++++++++++++++----------
 benchmarks/vllm_json_bench.txt  | 36 ++++++++-----
 5 files changed, 113 insertions(+), 53 deletions(-)

diff --git a/benchmarks/exllamav2_json.py b/benchmarks/exllamav2_json.py
index f48fa300..c52ce141 100644
--- a/benchmarks/exllamav2_json.py
+++ b/benchmarks/exllamav2_json.py
@@ -7,7 +7,7 @@
 from formatron.integrations.exllamav2 import create_formatter_filter
 from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter
 
-from benchmarks.utils import load_address, load_linkedlist, load_orders, force_gc, address_lfe, linked_list_lfe, \
+from utils import load_address, load_linkedlist, load_orders, force_gc, address_lfe, linked_list_lfe, \
     order_lfe
 from test_grammar_gen import LinkedList
 from utils import Address, BenchResult, Context, log
@@ -15,10 +15,10 @@
 def create_exllamav2_6bpw_llama3_8b():
-    model_dir = "../tests/local_assets/Llama-3-8B-exl2/"
+    model_dir = "../tests/local_assets/Meta-Llama-3-8B-Instruct-32k/"
     config = ExLlamaV2Config(model_dir)
     model = ExLlamaV2(config)
-    cache = ExLlamaV2Cache(model, max_seq_len=65536, lazy=True)
+    cache = ExLlamaV2Cache(model, max_seq_len=4096, lazy=True)
     model.load_autosplit(cache, progress=True)
     tokenizer = ExLlamaV2Tokenizer(config)
     generator = ExLlamaV2DynamicGenerator(
@@ -29,10 +29,10 @@ def create_exllamav2_6bpw_llama3_8b():
     return generator
 
 def create_exllamav2_4bpw_llama2_7b():
-    model_dir = "../tests/local_assets/Llama-2-7b-chat-hf-4.0-bpw-exl2/"
+    model_dir = "../tests/local_assets/LLaMA-2-7B-32K/"
     config = ExLlamaV2Config(model_dir)
     model = ExLlamaV2(config)
-    cache = ExLlamaV2Cache(model, max_seq_len=65536, lazy=True)
+    cache = ExLlamaV2Cache(model, max_seq_len=4096, lazy=True)
     model.load_autosplit(cache, progress=True)
     tokenizer = ExLlamaV2Tokenizer(config)
     generator = ExLlamaV2DynamicGenerator(
@@ -122,6 +122,10 @@ def bench(result:BenchResult, context:Context,func, bench_name:str, f):
     context = Context(0, 0)
     with open("exllamav2_json.txt", "w") as f:
         generator = create_exllamav2_6bpw_llama3_8b()
+        settings = ExLlamaV2Sampler.Settings()
+        settings.disallow_tokens(generator.tokenizer, [generator.tokenizer.eos_token_id])
+        generator.generate("Something", max_new_tokens=4080, gen_settings=settings) # warm up exllamav2 itself
+        settings = ExLlamaV2Sampler.Settings()
         system_prompt = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
 
 You are a helpful AI assistant for information extraction<|eot_id|><|start_header_id|>user<|end_header_id|>
@@ -132,7 +136,7 @@ def bench(result:BenchResult, context:Context,func, bench_name:str, f):
         # --------------------------------------------------------------------------------------------------------------
         inputs = load_address()
         context.filters = [f_get_address_filter()]
-        max_new_tokens = 100
+        max_new_tokens = 50
         bench(data, context, execute, "formatron_llama3_8b_6pw_exl2_address_json_exllamav2", f)
         context.filters = [lfe_get_address_filter()]
         bench(data, context, execute, "lm_format_enforcer_llama3_8b_6pw_exl2_address_json_exllamav2", f)
diff --git a/benchmarks/readme.md b/benchmarks/readme.md
index 74d389b5..9d7f3f4a 100644
--- a/benchmarks/readme.md
+++ b/benchmarks/readme.md
@@ -20,16 +20,17 @@ schema
 creation to the first run ends.
 ## vllm
 Default vllm settings are used.
-| model           | schema          | constrained(with warm-up) / tps | unconstrained / tps | overhead per token / ms |
-|-----------------|-----------------|---------------------------------|---------------------|-------------------------|
-| Llama3-8B(bf16) | address_json    | 40.82                           | 41.94               | 0.65                    |
-| Llama3-8B(bf16) | linkedlist_json | 40.56                           | 41.85               | 0.76                    |
-| Llama3-8B(bf16) | order_json      | 40.05                           | 41.46               | 0.84                    |
-| Llama2-7B(fp16) | address_json    | 46.57                           | 47.53               | 0.44                    |
-| Llama2-7B(fp16) | linkedlist_json | 46.51                           | 47.54               | 0.46                    |
-| Llama2-7B(fp16) | order_json      | 45.71                           | 46.68               | 0.46                    |
+| model           | schema          | Formatron overhead per token(with warm-up) / ms | lm format enforcer overhead(with warm-up) per token / ms |
+|-----------------|-----------------|-------------------------------------------------|----------------------------------------------------------|
+| Llama3-8B(bf16) | address_json    | 0.59                                            | 2.31                                                     |
+| Llama3-8B(bf16) | linkedlist_json | 0.66                                            | 0.26                                                     |
+| Llama3-8B(bf16) | order_json      | 0.64                                            | 0.92                                                     |
+| Llama2-7B(fp16) | address_json    | 0.33                                            | 0.33                                                     |
+| Llama2-7B(fp16) | linkedlist_json | 0.45                                            | 0.36                                                     |
+| Llama2-7B(fp16) | order_json      | 0.40                                            | 0.34                                                     |
 ## Exllamav2
 Default exllamav2 settings are used.
+Quantization likely influences the JSON outputs and hence affects the performance figures.
 
 | model                  | schema          | Formatron overhead per token(with warm-up) / ms | lm format enforcer overhead(with warm-up) per token / ms |
 |------------------------|-----------------|-------------------------------------------------|----------------------------------------------------------|
diff --git a/benchmarks/transformers_json.py b/benchmarks/transformers_json.py
index 1a9988fd..2fbc4aa3 100644
--- a/benchmarks/transformers_json.py
+++ b/benchmarks/transformers_json.py
@@ -108,6 +108,7 @@ def bench(result:BenchResult, context:Context,func, bench_name:str, f):
     tail = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
     model, tokenizer = get_llama3_8b_tokenizer_and_model()
     model.eval()
+    # ----------------------------------------------------------------------------------------------------------
     max_new_tokens = 50
     inputs = load_address()
     prefix_fn = None
@@ -116,6 +117,7 @@ def bench(result:BenchResult, context:Context,func, bench_name:str, f):
     logits_processor = None
     prefix_fn = lfe_address_prefix()
     bench(data, context, execute, "lm_format_enforcer_llama3_8b_address_json", f)
+    # ----------------------------------------------------------------------------------------------------------
     inputs = load_linkedlist()
     logits_processor = get_linkedlist_schema()
     max_new_tokens = 200
@@ -123,12 +125,14 @@ def bench(result:BenchResult, context:Context,func, bench_name:str, f):
     logits_processor = None
     prefix_fn = lfe_linkedlist_prefix()
     bench(data, context, execute, "lm_format_enforcer_llama3_8b_linkedlist_json", f)
+    # ----------------------------------------------------------------------------------------------------------
     inputs = load_orders()
     logits_processor = get_order_schema()
     bench(data, context, execute, "formatron_llama3_8b_order_json", f)
     logits_processor = None
     prefix_fn = lfe_order_prefix()
     bench(data, context, execute, "lm_format_enforcer_llama3_8b_order_json", f)
+    # ----------------------------------------------------------------------------------------------------------
     system_prompt = """[INST] You are a helpful AI assistant for information extraction.
@@ -140,6 +144,7 @@ def bench(result:BenchResult, context:Context,func, bench_name:str, f):
     torch.cuda.empty_cache()
     model, tokenizer = get_llama2_7b_tokenizer_and_model()
     model.eval()
+    # ----------------------------------------------------------------------------------------------------------
     max_new_tokens = 50
     inputs = load_address()
     logits_processor = get_address_schema()
@@ -147,6 +152,7 @@ def bench(result:BenchResult, context:Context,func, bench_name:str, f):
     logits_processor = None
     prefix_fn = lfe_address_prefix()
     bench(data, context, execute, "lm_format_enforcer_llama2_7b_address_json", f)
+    # ----------------------------------------------------------------------------------------------------------
    max_new_tokens = 30
     inputs = load_linkedlist()
     logits_processor = get_linkedlist_schema()
@@ -155,6 +161,7 @@ def bench(result:BenchResult, context:Context,func, bench_name:str, f):
     logits_processor = None
     prefix_fn = lfe_linkedlist_prefix()
     bench(data, context, execute, "lm_format_enforcer_llama2_7b_linkedlist_json", f)
+    # ----------------------------------------------------------------------------------------------------------
     inputs = load_orders()
     logits_processor = get_order_schema()
     bench(data, context, execute, "formatron_llama2_7b_order_json", f)
diff --git a/benchmarks/vllm_json.py b/benchmarks/vllm_json.py
index 19fe62a5..2dda47ea 100644
--- a/benchmarks/vllm_json.py
+++ b/benchmarks/vllm_json.py
@@ -7,13 +7,15 @@
 from vllm import LLM, SamplingParams
 from vllm.distributed import destroy_model_parallel, destroy_distributed_environment
 
+from utils import load_address, load_linkedlist, load_orders
+from utils import address_lfe, linked_list_lfe, order_lfe
 from formatter import FormatterBuilder
 from integrations.vllm import create_formatters_logits_processor, FormattersLogitsProcessor
 from utils import Address
 from utils import BenchResult, Context
 from utils import LinkedList
 from utils import Order, log
-
+from lmformatenforcer.integrations.vllm import build_vllm_logits_processor
 
 def execute():
     prompts = [
@@ -28,27 +30,42 @@ def execute():
     l[0].reset()
 
 
-def get_vllm_address():
+def formatron_vllm_address():
     f = FormatterBuilder()
-    f.append_line(f"```json\n{f.schema(Address, JsonGenerator(), capture_name='json')}```")
+    f.append_line(f"\n{f.schema(Address, JsonGenerator(), capture_name='json')}")
     logits_processor = create_formatters_logits_processor(llm, [f])
     sampling_params = SamplingParams(temperature=0.8, top_p=0.95,max_tokens=100, logits_processors=[logits_processor])
     return sampling_params
 
-def get_vllm_linkedlist():
+def lfe_vllm_address():
+    logits_processor = build_vllm_logits_processor(llm, address_lfe)
+    sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=100, logits_processors=[logits_processor])
+    return sampling_params
+
+def formatron_vllm_linkedlist():
     f = FormatterBuilder()
-    f.append_line(f"```json\n{f.schema(LinkedList, JsonGenerator(), capture_name='json')}```")
+    f.append_line(f"{f.schema(LinkedList, JsonGenerator(), capture_name='json')}")
     logits_processor = create_formatters_logits_processor(llm, [f])
     sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=100, logits_processors=[logits_processor])
     return sampling_params
 
-def get_vllm_order():
+def lfe_vllm_linkedlist():
+    logits_processor = build_vllm_logits_processor(llm, linked_list_lfe)
+    sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=100, logits_processors=[logits_processor])
+    return sampling_params
+
+def formatron_vllm_order():
     f = FormatterBuilder()
-    f.append_line(f"```json\n{f.schema(Order, JsonGenerator(), capture_name='json')}```")
+    f.append_line(f"{f.schema(Order, JsonGenerator(), capture_name='json')}")
     logits_processor = create_formatters_logits_processor(llm, [f])
     sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=256, logits_processors=[logits_processor])
     return sampling_params
 
+def lfe_vllm_order():
+    logits_processor = build_vllm_logits_processor(llm, order_lfe)
+    sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=256, logits_processors=[logits_processor])
+    return sampling_params
+
 def warm_up(f):
     f()
     context.index = 0
@@ -81,17 +98,27 @@ def bench(result:BenchResult, context:Context,func, bench_name:str, f):
 Extract information into json format:
 """
     tail = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
     import os
-    os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+    os.environ["CUDA_VISIBLE_DEVICES"] = "7"
     llm = LLM(model="NurtureAI/Meta-Llama-3-8B-Instruct-32k", max_model_len=4096)
-    inputs = json.load(open("address.json"))["sentences"]
-    sampling_params = get_vllm_address()
-    bench(data, context, execute, "llama3_8b_vllm_address", f)
-    sampling_params = get_vllm_linkedlist()
-    inputs = json.load(open("linkedlist.json"))["sentences"]
-    bench(data, context, execute, "llama3_8b_linkedlist", f)
-    sampling_params = get_vllm_order()
-    inputs = json.load(open("orders.json"))["orders"]
-    bench(data, context, execute, "llama3_8b_orders", f)
+    # --------------------------------------------------------------------------------------------------------------
+    inputs = load_address()
+    sampling_params = formatron_vllm_address()
+    bench(data, context, execute, "formatron_llama3_8b_address", f)
+    sampling_params = lfe_vllm_address()
+    bench(data, context, execute, "lm_format_enforcer_llama3_8b_address", f)
+    # --------------------------------------------------------------------------------------------------------------
+    sampling_params = formatron_vllm_linkedlist()
+    inputs = load_linkedlist()
+    bench(data, context, execute, "formatron_llama3_8b_linkedlist", f)
+    sampling_params = lfe_vllm_linkedlist()
+    bench(data, context, execute, "lm_format_enforcer_llama3_8b_linkedlist", f)
+    # --------------------------------------------------------------------------------------------------------------
+    sampling_params = formatron_vllm_order()
+    inputs = load_orders()
+    bench(data, context, execute, "formatron_llama3_8b_orders", f)
+    sampling_params = lfe_vllm_order()
+    bench(data, context, execute, "lm_format_enforcer_llama3_8b_order", f)
+    # --------------------------------------------------------------------------------------------------------------
     destroy_model_parallel()
     destroy_distributed_environment()
     del llm.llm_engine.model_executor
@@ -103,13 +130,22 @@ def bench(result:BenchResult, context:Context,func, bench_name:str, f):
 Extract information into json format:
 """
     tail = "[/INST]"
-    llm = LLM(model="daryl149/llama-2-7b-chat-hf", max_model_len=2048)
-    inputs = json.load(open("address.json"))["sentences"]
-    sampling_params = get_vllm_address()
-    bench(data, context,execute, "llama2_7b_vllm_address", f)
-    sampling_params = get_vllm_linkedlist()
-    inputs = json.load(open("linkedlist.json"))["sentences"]
-    bench(data, context, execute, "llama2_7b_linkedlist", f)
-    sampling_params = get_vllm_order()
-    inputs = json.load(open("orders.json"))["orders"]
-    bench(data, context, execute, "llama2_7b_orders", f)
\ No newline at end of file
+    llm = LLM(model="togethercomputer/LLaMA-2-7B-32K", max_model_len=4096)
+    # --------------------------------------------------------------------------------------------------------------
+    inputs = load_address()
+    sampling_params = formatron_vllm_address()
+    bench(data, context,execute, "formatron_llama2_7b_address", f)
+    sampling_params = lfe_vllm_address()
+    bench(data, context,execute, "lm_format_enforcer_llama2_7b_address", f)
+    # --------------------------------------------------------------------------------------------------------------
+    sampling_params = formatron_vllm_linkedlist()
+    inputs = load_linkedlist()
+    bench(data, context, execute, "formatron_llama2_7b_linkedlist", f)
+    sampling_params = lfe_vllm_linkedlist()
+    bench(data, context, execute, "lm_format_enforcer_llama2_7b_linkedlist", f)
+    # --------------------------------------------------------------------------------------------------------------
+    sampling_params = formatron_vllm_order()
+    inputs = load_orders()
+    bench(data, context, execute, "formatron_llama2_7b_orders", f)
+    sampling_params = lfe_vllm_order()
+    bench(data, context, execute, "lm_format_enforcer_llama2_7b_orders", f)
\ No newline at end of file
diff --git a/benchmarks/vllm_json_bench.txt b/benchmarks/vllm_json_bench.txt
index 4e1d61fa..ed4b9888 100644
--- a/benchmarks/vllm_json_bench.txt
+++ b/benchmarks/vllm_json_bench.txt
@@ -1,12 +1,24 @@
-llama3_8b_vllm_address generated 1085 tokens with 40.81784246815824 tps (with warm up)
-llama3_8b_vllm_address unconstrained generated 1526 tokens with 41.941802458298596 tps
-llama3_8b_linkedlist generated 1252 tokens with 40.55915109806643 tps (with warm up)
-llama3_8b_linkedlist unconstrained generated 1389 tokens with 41.85204307258889 tps
-llama3_8b_orders generated 4266 tokens with 40.055533063427156 tps (with warm up)
-llama3_8b_orders unconstrained generated 4595 tokens with 41.46372280395214 tps
-llama2_7b_vllm_address generated 1956 tokens with 46.56639728880725 tps (with warm up)
-llama2_7b_vllm_address unconstrained generated 1723 tokens with 47.53091890475664 tps
-llama2_7b_linkedlist generated 1590 tokens with 46.51200582753253 tps (with warm up)
-llama2_7b_linkedlist unconstrained generated 1915 tokens with 47.54641610018721 tps
-llama2_7b_orders generated 5073 tokens with 45.7052185686332 tps (with warm up)
-llama2_7b_orders unconstrained generated 5120 tokens with 46.68000079667531 tps
+formatron_llama3_8b_address generated 837 tokens with 40.92543858162237 tps (with warm up)
+formatron_llama3_8b_address unconstrained generated 1517 tokens with 42.03973057981289 tps
+lm_format_enforcer_llama3_8b_address generated 900 tokens with 38.382952379822854 tps (with warm up)
+lm_format_enforcer_llama3_8b_address unconstrained generated 1557 tokens with 41.95753842260187 tps
+formatron_llama3_8b_linkedlist generated 1253 tokens with 40.65510144702983 tps (with warm up)
+formatron_llama3_8b_linkedlist unconstrained generated 1280 tokens with 41.96170843472007 tps
+lm_format_enforcer_llama3_8b_linkedlist generated 1046 tokens with 41.49508056791946 tps (with warm up)
+lm_format_enforcer_llama3_8b_linkedlist unconstrained generated 1370 tokens with 41.93690985976558 tps
+formatron_llama3_8b_orders generated 3846 tokens with 40.30895091864163 tps (with warm up)
+formatron_llama3_8b_orders unconstrained generated 4577 tokens with 41.55286165339125 tps
+lm_format_enforcer_llama3_8b_order generated 3767 tokens with 40.31444158325534 tps (with warm up)
+lm_format_enforcer_llama3_8b_order unconstrained generated 4581 tokens with 41.55670827797047 tps
+formatron_llama2_7b_address generated 1353 tokens with 46.79902509681338 tps (with warm up)
+formatron_llama2_7b_address unconstrained generated 2000 tokens with 47.6050425184716 tps
+lm_format_enforcer_llama2_7b_address generated 1614 tokens with 46.85370459223725 tps (with warm up)
+lm_format_enforcer_llama2_7b_address unconstrained generated 2000 tokens with 47.591727879190344 tps
+formatron_llama2_7b_linkedlist generated 1022 tokens with 46.64185711280541 tps (with warm up)
+formatron_llama2_7b_linkedlist unconstrained generated 2000 tokens with 47.655055230835984 tps
+lm_format_enforcer_llama2_7b_linkedlist generated 1432 tokens with 46.83670664952844 tps (with warm up)
+lm_format_enforcer_llama2_7b_linkedlist unconstrained generated 2000 tokens with 47.650351474620415 tps
+formatron_llama2_7b_orders generated 4596 tokens with 45.86432176919425 tps (with warm up)
+formatron_llama2_7b_orders unconstrained generated 5120 tokens with 46.742264169348076 tps
+lm_format_enforcer_llama2_7b_orders generated 4834 tokens with 45.98683334304331 tps (with warm up)
+lm_format_enforcer_llama2_7b_orders unconstrained generated 5021 tokens with 46.744708931481135 tps
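
Each benchmark in vllm_json_bench.txt is logged as a constrained/unconstrained tps pair, and the readme's "overhead per token / ms" columns appear to be the difference of the two per-token latencies. The following is a minimal sketch of that derivation, assuming this is indeed the formula; it reproduces the pre-patch readme table from the pre-patch log exactly, while the updated table may come from a different run than the log in this patch.

```python
# Sketch (not part of the patch): the presumed relation between the tps pairs
# logged in vllm_json_bench.txt and the readme's per-token overhead columns.
def overhead_ms_per_token(constrained_tps: float, unconstrained_tps: float) -> float:
    # Difference of per-token latencies, converted from seconds to milliseconds.
    return 1000.0 * (1.0 / constrained_tps - 1.0 / unconstrained_tps)

# Pre-patch example: llama2_7b_vllm_address ran at 46.566 tps constrained vs.
# 47.531 tps unconstrained, i.e. about 0.44 ms of overhead per token,
# which matches the old readme table.
print(round(overhead_ms_per_token(46.56639728880725, 47.53091890475664), 2))  # 0.44
```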
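The `address_lfe`, `linked_list_lfe`, and `order_lfe` parsers that `vllm_json.py` now imports are defined in `benchmarks/utils.py`, which this patch does not touch. Below is a hypothetical sketch of how such a parser could be built and wired into vllm, mirroring `lfe_vllm_address` above; the `Address` fields are placeholders, not the benchmark's real schema.

```python
# Hypothetical sketch: building an lm-format-enforcer parser like address_lfe
# and attaching it to vllm. The real definitions live in benchmarks/utils.py,
# which is not shown in this patch.
from pydantic import BaseModel
from lmformatenforcer import JsonSchemaParser
from lmformatenforcer.integrations.vllm import build_vllm_logits_processor
from vllm import LLM, SamplingParams

class Address(BaseModel):
    # Placeholder fields; the benchmark's actual Address schema is in utils.py.
    street: str
    city: str
    country: str

llm = LLM(model="NurtureAI/Meta-Llama-3-8B-Instruct-32k", max_model_len=4096)
# A character-level parser driven by the Pydantic model's JSON schema.
address_lfe = JsonSchemaParser(Address.model_json_schema())
logits_processor = build_vllm_logits_processor(llm, address_lfe)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=100,
                                 logits_processors=[logits_processor])
# outputs = llm.generate(prompts, sampling_params)  # as in execute() above
```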