Update vllm's lm-format-enforcer benchmark

Dan-wanna-M committed Aug 23, 2024
1 parent 40f4632 commit 7100cfd
Showing 5 changed files with 113 additions and 53 deletions.
16 changes: 10 additions & 6 deletions benchmarks/exllamav2_json.py
@@ -7,18 +7,18 @@
from formatron.integrations.exllamav2 import create_formatter_filter
from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter

-from benchmarks.utils import load_address, load_linkedlist, load_orders, force_gc, address_lfe, linked_list_lfe, \
+from utils import load_address, load_linkedlist, load_orders, force_gc, address_lfe, linked_list_lfe, \
order_lfe
from test_grammar_gen import LinkedList
from utils import Address, BenchResult, Context, log
from utils import Order


def create_exllamav2_6bpw_llama3_8b():
model_dir = "../tests/local_assets/Llama-3-8B-exl2/"
model_dir = "../tests/local_assets/Meta-Llama-3-8B-Instruct-32k/"
config = ExLlamaV2Config(model_dir)
model = ExLlamaV2(config)
-cache = ExLlamaV2Cache(model, max_seq_len=65536, lazy=True)
+cache = ExLlamaV2Cache(model, max_seq_len=4096, lazy=True)
model.load_autosplit(cache, progress=True)
tokenizer = ExLlamaV2Tokenizer(config)
generator = ExLlamaV2DynamicGenerator(
@@ -29,10 +29,10 @@ def create_exllamav2_6bpw_llama3_8b():
return generator

def create_exllamav2_4bpw_llama2_7b():
model_dir = "../tests/local_assets/Llama-2-7b-chat-hf-4.0-bpw-exl2/"
model_dir = "../tests/local_assets/LLaMA-2-7B-32K/"
config = ExLlamaV2Config(model_dir)
model = ExLlamaV2(config)
-cache = ExLlamaV2Cache(model, max_seq_len=65536, lazy=True)
+cache = ExLlamaV2Cache(model, max_seq_len=4096, lazy=True)
model.load_autosplit(cache, progress=True)
tokenizer = ExLlamaV2Tokenizer(config)
generator = ExLlamaV2DynamicGenerator(
@@ -122,6 +122,10 @@ def bench(result:BenchResult, context:Context,func, bench_name:str, f):
context = Context(0, 0)
with open("exllamav2_json.txt", "w") as f:
generator = create_exllamav2_6bpw_llama3_8b()
+settings = ExLlamaV2Sampler.Settings()
+settings.disallow_tokens(generator.tokenizer, [generator.tokenizer.eos_token_id])
+generator.generate("Something", max_new_tokens=4080, gen_settings=settings) # warm up exllamav2 itself
+settings = ExLlamaV2Sampler.Settings()
system_prompt = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful AI assistant for information extraction<|eot_id|><|start_header_id|>user<|end_header_id|>
@@ -132,7 +136,7 @@ def bench(result:BenchResult, context:Context,func, bench_name:str, f):
# --------------------------------------------------------------------------------------------------------------
inputs = load_address()
context.filters = [f_get_address_filter()]
-max_new_tokens = 100
+max_new_tokens = 50
bench(data, context, execute, "formatron_llama3_8b_6pw_exl2_address_json_exllamav2", f)
context.filters = [lfe_get_address_filter()]
bench(data, context, execute, "lm_format_enforcer_llama3_8b_6pw_exl2_address_json_exllamav2", f)
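For reference, both filter kinds plug into ExLlamaV2 the same way. A minimal sketch, assuming the `create_formatter_filter(model, tokenizer, formatter_builder)` signature, the `JsonGenerator` import path used elsewhere in these benchmarks, and the `Address` schema from `benchmarks/utils.py`:

```python
# Sketch only: signatures and import paths are assumptions, not this commit's code.
from formatron.formatter import FormatterBuilder
from formatron.grammar_generators.json_generator import JsonGenerator  # path assumed
from formatron.integrations.exllamav2 import create_formatter_filter

f = FormatterBuilder()
f.append_line(f"{f.schema(Address, JsonGenerator(), capture_name='json')}")  # Address from benchmarks/utils.py
formatron_filter = create_formatter_filter(generator.model, generator.tokenizer, f)

# The dynamic generator accepts per-request token filters, so either library's
# filter slots into the same call:
output = generator.generate(
    prompt=system_prompt + inputs[0] + tail,
    max_new_tokens=50,
    filters=[formatron_filter],  # or an ExLlamaV2TokenEnforcerFilter for lm-format-enforcer
)
```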
17 changes: 9 additions & 8 deletions benchmarks/readme.md
@@ -20,16 +20,17 @@ schema creation to the first run ends.
## vllm
Default vllm settings are used.

-| model           | schema          | constrained(with warm-up) / tps | unconstrained / tps | overhead per token / ms |
-|-----------------|-----------------|---------------------------------|---------------------|-------------------------|
-| Llama3-8B(bf16) | address_json    | 40.82                           | 41.94               | 0.65                    |
-| Llama3-8B(bf16) | linkedlist_json | 40.56                           | 41.85               | 0.76                    |
-| Llama3-8B(bf16) | order_json      | 40.05                           | 41.46               | 0.84                    |
-| Llama2-7B(fp16) | address_json    | 46.57                           | 47.53               | 0.44                    |
-| Llama2-7B(fp16) | linkedlist_json | 46.51                           | 47.54               | 0.46                    |
-| Llama2-7B(fp16) | order_json      | 45.71                           | 46.68               | 0.46                    |
+| model           | schema          | Formatron overhead per token (with warm-up) / ms | lm-format-enforcer overhead per token (with warm-up) / ms |
+|-----------------|-----------------|---------------------------------------------------|------------------------------------------------------------|
+| Llama3-8B(bf16) | address_json    | 0.59                                              | 2.31                                                       |
+| Llama3-8B(bf16) | linkedlist_json | 0.66                                              | 0.26                                                       |
+| Llama3-8B(bf16) | order_json      | 0.64                                              | 0.92                                                       |
+| Llama2-7B(fp16) | address_json    | 0.33                                              | 0.33                                                       |
+| Llama2-7B(fp16) | linkedlist_json | 0.45                                              | 0.36                                                       |
+| Llama2-7B(fp16) | order_json      | 0.40                                              | 0.34                                                       |
## Exllamav2
Default exllamav2 settings are used.
+Quantization likely influences JSON outputs and hence affects performance.

| model                  | schema          | Formatron overhead per token (with warm-up) / ms | lm-format-enforcer overhead per token (with warm-up) / ms |
|------------------------|-----------------|---------------------------------------------------|------------------------------------------------------------|
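The overhead columns in both tables can be recomputed from the tps values in vllm_json_bench.txt (diffed at the bottom of this commit). A minimal sketch of the presumed derivation, assuming overhead is the per-token latency difference between constrained and unconstrained runs; it reproduces the removed vllm table's column to within about 0.01 ms, though the exact rounding used isn't shown in this diff:

```python
def overhead_ms_per_token(constrained_tps: float, unconstrained_tps: float) -> float:
    # 1000/tps is milliseconds per token; overhead is the latency difference.
    return 1000.0 / constrained_tps - 1000.0 / unconstrained_tps

# First row of the removed table: 40.82 tps constrained vs 41.94 tps unconstrained.
print(round(overhead_ms_per_token(40.82, 41.94), 2))  # 0.65, matching the removed table
```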
7 changes: 7 additions & 0 deletions benchmarks/transformers_json.py
@@ -108,6 +108,7 @@ def bench(result:BenchResult, context:Context,func, bench_name:str, f):
tail = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
model, tokenizer = get_llama3_8b_tokenizer_and_model()
model.eval()
+# ----------------------------------------------------------------------------------------------------------
max_new_tokens = 50
inputs = load_address()
prefix_fn = None
@@ -116,19 +117,22 @@ def bench(result:BenchResult, context:Context,func, bench_name:str, f):
logits_processor = None
prefix_fn = lfe_address_prefix()
bench(data, context, execute, "lm_format_enforcer_llama3_8b_address_json", f)
+# ----------------------------------------------------------------------------------------------------------
inputs = load_linkedlist()
logits_processor = get_linkedlist_schema()
max_new_tokens = 200
bench(data,context,execute, "formatron_llama3_8b_linkedlist_json", f)
logits_processor = None
prefix_fn = lfe_linkedlist_prefix()
bench(data, context, execute, "lm_format_enforcer_llama3_8b_linkedlist_json", f)
+# ----------------------------------------------------------------------------------------------------------
inputs = load_orders()
logits_processor = get_order_schema()
bench(data, context, execute, "formatron_llama3_8b_order_json", f)
logits_processor = None
prefix_fn = lfe_order_prefix()
bench(data, context, execute, "lm_format_enforcer_llama3_8b_order_json", f)
+# ----------------------------------------------------------------------------------------------------------
system_prompt = """[INST]
You are a helpful AI assistant for information extraction.
@@ -140,13 +144,15 @@ def bench(result:BenchResult, context:Context,func, bench_name:str, f):
torch.cuda.empty_cache()
model, tokenizer = get_llama2_7b_tokenizer_and_model()
model.eval()
+# ----------------------------------------------------------------------------------------------------------
max_new_tokens = 50
inputs = load_address()
logits_processor = get_address_schema()
bench(data, context, execute, "formatron_llama2_7b_address_json", f)
logits_processor = None
prefix_fn = lfe_address_prefix()
bench(data, context, execute, "lm_format_enforcer_llama2_7b_address_json", f)
+# ----------------------------------------------------------------------------------------------------------
max_new_tokens = 30
inputs = load_linkedlist()
logits_processor = get_linkedlist_schema()
@@ -155,6 +161,7 @@ def bench(result:BenchResult, context:Context,func, bench_name:str, f):
logits_processor = None
prefix_fn = lfe_linkedlist_prefix()
bench(data, context, execute, "lm_format_enforcer_llama2_7b_linkedlist_json", f)
+# ----------------------------------------------------------------------------------------------------------
inputs = load_orders()
logits_processor = get_order_schema()
bench(data, context, execute, "formatron_llama2_7b_order_json", f)
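Each separator added above delimits one benchmark case, and every case follows the same transformers pattern. A minimal sketch, assuming `get_address_schema()` returns a transformers-compatible logits processor and `lfe_address_prefix()` a prefix-allowed-tokens function, as the helper names in this file suggest:

```python
from transformers import LogitsProcessorList

# One constrained generation, mirroring a single bench case above.
prompt = system_prompt + inputs[0] + tail
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

output = model.generate(
    input_ids,
    max_new_tokens=50,
    logits_processor=LogitsProcessorList([get_address_schema()]),  # Formatron path
    # prefix_allowed_tokens_fn=lfe_address_prefix(),               # lm-format-enforcer path
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```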
90 changes: 63 additions & 27 deletions benchmarks/vllm_json.py
@@ -7,13 +7,15 @@
from vllm import LLM, SamplingParams
from vllm.distributed import destroy_model_parallel, destroy_distributed_environment

+from utils import load_address, load_linkedlist, load_orders
+from utils import address_lfe, linked_list_lfe, order_lfe
from formatter import FormatterBuilder
from integrations.vllm import create_formatters_logits_processor, FormattersLogitsProcessor
from utils import Address
from utils import BenchResult, Context
from utils import LinkedList
from utils import Order, log

+from lmformatenforcer.integrations.vllm import build_vllm_logits_processor

def execute():
prompts = [
@@ -28,27 +30,42 @@ def execute():
l[0].reset()


-def get_vllm_address():
+def formatron_vllm_address():
f = FormatterBuilder()
f.append_line(f"```json\n{f.schema(Address, JsonGenerator(), capture_name='json')}```")
f.append_line(f"\n{f.schema(Address, JsonGenerator(), capture_name='json')}")
logits_processor = create_formatters_logits_processor(llm, [f])
sampling_params = SamplingParams(temperature=0.8, top_p=0.95,max_tokens=100, logits_processors=[logits_processor])
return sampling_params

-def get_vllm_linkedlist():
+def lfe_vllm_address():
+logits_processor = build_vllm_logits_processor(llm, address_lfe)
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=100, logits_processors=[logits_processor])
+return sampling_params
+
+def formatron_vllm_linkedlist():
f = FormatterBuilder()
f.append_line(f"```json\n{f.schema(LinkedList, JsonGenerator(), capture_name='json')}```")
f.append_line(f"{f.schema(LinkedList, JsonGenerator(), capture_name='json')}")
logits_processor = create_formatters_logits_processor(llm, [f])
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=100, logits_processors=[logits_processor])
return sampling_params

-def get_vllm_order():
+def lfe_vllm_linkedlist():
+logits_processor = build_vllm_logits_processor(llm, linked_list_lfe)
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=100, logits_processors=[logits_processor])
+return sampling_params
+
+def formatron_vllm_order():
f = FormatterBuilder()
f.append_line(f"```json\n{f.schema(Order, JsonGenerator(), capture_name='json')}```")
f.append_line(f"{f.schema(Order, JsonGenerator(), capture_name='json')}")
logits_processor = create_formatters_logits_processor(llm, [f])
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=256, logits_processors=[logits_processor])
return sampling_params

+def lfe_vllm_order():
+logits_processor = build_vllm_logits_processor(llm, order_lfe)
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=256, logits_processors=[logits_processor])
+return sampling_params
+
def warm_up(f):
f()
context.index = 0
@@ -81,17 +98,27 @@ def bench(result:BenchResult, context:Context,func, bench_name:str, f):
Extract information into json format: """
tail = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
import os
-os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+os.environ["CUDA_VISIBLE_DEVICES"] = "7"
llm = LLM(model="NurtureAI/Meta-Llama-3-8B-Instruct-32k", max_model_len=4096)
inputs = json.load(open("address.json"))["sentences"]
sampling_params = get_vllm_address()
bench(data, context, execute, "llama3_8b_vllm_address", f)
sampling_params = get_vllm_linkedlist()
inputs = json.load(open("linkedlist.json"))["sentences"]
bench(data, context, execute, "llama3_8b_linkedlist", f)
sampling_params = get_vllm_order()
inputs = json.load(open("orders.json"))["orders"]
bench(data, context, execute, "llama3_8b_orders", f)
# --------------------------------------------------------------------------------------------------------------
inputs = load_address()
sampling_params = formatron_vllm_address()
bench(data, context, execute, "formatron_llama3_8b_address", f)
sampling_params = lfe_vllm_address()
bench(data, context, execute, "lm_format_enforcer_llama3_8b_address", f)
# --------------------------------------------------------------------------------------------------------------
sampling_params = formatron_vllm_linkedlist()
inputs = load_linkedlist()
bench(data, context, execute, "formatron_llama3_8b_linkedlist", f)
sampling_params = lfe_vllm_linkedlist()
bench(data, context, execute, "lm_format_enforcer_llama3_8b_linkedlist", f)
# --------------------------------------------------------------------------------------------------------------
sampling_params = formatron_vllm_order()
inputs = load_orders()
bench(data, context, execute, "formatron_llama3_8b_orders", f)
sampling_params = lfe_vllm_order()
bench(data, context, execute, "lm_format_enforcer_llama3_8b_order", f)
# --------------------------------------------------------------------------------------------------------------
destroy_model_parallel()
destroy_distributed_environment()
del llm.llm_engine.model_executor
@@ -103,13 +130,22 @@ def bench(result:BenchResult, context:Context,func, bench_name:str, f):
Extract information into json format: """
tail = "[/INST]"
llm = LLM(model="daryl149/llama-2-7b-chat-hf", max_model_len=2048)
inputs = json.load(open("address.json"))["sentences"]
sampling_params = get_vllm_address()
bench(data, context,execute, "llama2_7b_vllm_address", f)
sampling_params = get_vllm_linkedlist()
inputs = json.load(open("linkedlist.json"))["sentences"]
bench(data, context, execute, "llama2_7b_linkedlist", f)
sampling_params = get_vllm_order()
inputs = json.load(open("orders.json"))["orders"]
bench(data, context, execute, "llama2_7b_orders", f)
llm = LLM(model="togethercomputer/LLaMA-2-7B-32K", max_model_len=4096)
# --------------------------------------------------------------------------------------------------------------
inputs = load_address()
sampling_params = formatron_vllm_address()
bench(data, context,execute, "formatron_llama2_7b_address", f)
sampling_params = lfe_vllm_address()
bench(data, context,execute, "lm_format_enforcer_llama2_7b_address", f)
# --------------------------------------------------------------------------------------------------------------
sampling_params = formatron_vllm_linkedlist()
inputs = load_linkedlist()
bench(data, context, execute, "formatron_llama2_7b_linkedlist", f)
sampling_params = lfe_vllm_linkedlist()
bench(data, context, execute, "lm_format_enforcer_llama2_7b_linkedlist", f)
# --------------------------------------------------------------------------------------------------------------
sampling_params = formatron_vllm_order()
inputs = load_orders()
bench(data, context, execute, "formatron_llama2_7b_orders", f)
sampling_params = lfe_vllm_order()
bench(data, context, execute, "lm_format_enforcer_llama2_7b_orders", f)
36 changes: 24 additions & 12 deletions benchmarks/vllm_json_bench.txt
@@ -1,12 +1,24 @@
-llama3_8b_vllm_address generated 1085 tokens with 40.81784246815824 tps (with warm up)
-llama3_8b_vllm_address unconstrained generated 1526 tokens with 41.941802458298596 tps
-llama3_8b_linkedlist generated 1252 tokens with 40.55915109806643 tps (with warm up)
-llama3_8b_linkedlist unconstrained generated 1389 tokens with 41.85204307258889 tps
-llama3_8b_orders generated 4266 tokens with 40.055533063427156 tps (with warm up)
-llama3_8b_orders unconstrained generated 4595 tokens with 41.46372280395214 tps
-llama2_7b_vllm_address generated 1956 tokens with 46.56639728880725 tps (with warm up)
-llama2_7b_vllm_address unconstrained generated 1723 tokens with 47.53091890475664 tps
-llama2_7b_linkedlist generated 1590 tokens with 46.51200582753253 tps (with warm up)
-llama2_7b_linkedlist unconstrained generated 1915 tokens with 47.54641610018721 tps
-llama2_7b_orders generated 5073 tokens with 45.7052185686332 tps (with warm up)
-llama2_7b_orders unconstrained generated 5120 tokens with 46.68000079667531 tps
+formatron_llama3_8b_address generated 837 tokens with 40.92543858162237 tps (with warm up)
+formatron_llama3_8b_address unconstrained generated 1517 tokens with 42.03973057981289 tps
+lm_format_enforcer_llama3_8b_address generated 900 tokens with 38.382952379822854 tps (with warm up)
+lm_format_enforcer_llama3_8b_address unconstrained generated 1557 tokens with 41.95753842260187 tps
+formatron_llama3_8b_linkedlist generated 1253 tokens with 40.65510144702983 tps (with warm up)
+formatron_llama3_8b_linkedlist unconstrained generated 1280 tokens with 41.96170843472007 tps
+lm_format_enforcer_llama3_8b_linkedlist generated 1046 tokens with 41.49508056791946 tps (with warm up)
+lm_format_enforcer_llama3_8b_linkedlist unconstrained generated 1370 tokens with 41.93690985976558 tps
+formatron_llama3_8b_orders generated 3846 tokens with 40.30895091864163 tps (with warm up)
+formatron_llama3_8b_orders unconstrained generated 4577 tokens with 41.55286165339125 tps
+lm_format_enforcer_llama3_8b_order generated 3767 tokens with 40.31444158325534 tps (with warm up)
+lm_format_enforcer_llama3_8b_order unconstrained generated 4581 tokens with 41.55670827797047 tps
+formatron_llama2_7b_address generated 1353 tokens with 46.79902509681338 tps (with warm up)
+formatron_llama2_7b_address unconstrained generated 2000 tokens with 47.6050425184716 tps
+lm_format_enforcer_llama2_7b_address generated 1614 tokens with 46.85370459223725 tps (with warm up)
+lm_format_enforcer_llama2_7b_address unconstrained generated 2000 tokens with 47.591727879190344 tps
+formatron_llama2_7b_linkedlist generated 1022 tokens with 46.64185711280541 tps (with warm up)
+formatron_llama2_7b_linkedlist unconstrained generated 2000 tokens with 47.655055230835984 tps
+lm_format_enforcer_llama2_7b_linkedlist generated 1432 tokens with 46.83670664952844 tps (with warm up)
+lm_format_enforcer_llama2_7b_linkedlist unconstrained generated 2000 tokens with 47.650351474620415 tps
+formatron_llama2_7b_orders generated 4596 tokens with 45.86432176919425 tps (with warm up)
+formatron_llama2_7b_orders unconstrained generated 5120 tokens with 46.742264169348076 tps
+lm_format_enforcer_llama2_7b_orders generated 4834 tokens with 45.98683334304331 tps (with warm up)
+lm_format_enforcer_llama2_7b_orders unconstrained generated 5021 tokens with 46.744708931481135 tps
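Every line in this log follows the pattern `<name>[ unconstrained] generated <N> tokens with <T> tps`, so the readme's overhead columns can be recomputed with a short script. The subtraction-of-latencies formula is the same assumption noted in the readme section above:

```python
import re

# Pair constrained/unconstrained tps per benchmark name and print the
# per-token overhead in milliseconds (1000/tps is ms per token).
pattern = re.compile(r"^(\S+)( unconstrained)? generated \d+ tokens with ([\d.]+) tps")
tps = {}
with open("vllm_json_bench.txt") as fh:
    for line in fh:
        m = pattern.match(line)
        if m:
            tps[(m.group(1), m.group(2) is None)] = float(m.group(3))

for (name, constrained), value in sorted(tps.items()):
    if constrained and (name, False) in tps:
        overhead = 1000.0 / value - 1000.0 / tps[(name, False)]
        print(f"{name}: {overhead:.2f} ms/token")
```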
