Skip to content

Commit

Permalink
Include lm-format-enforcer in benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
Dan-wanna-M committed Aug 23, 2024
1 parent d8a9275 commit 40f4632
Show file tree
Hide file tree
Showing 6 changed files with 226 additions and 105 deletions.
126 changes: 82 additions & 44 deletions benchmarks/exllamav2_json.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import gc
import json
from timeit import timeit

import torch
import formatron.integrations.exllamav2
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2Sampler
from formatron.formatter import FormatterBuilder
from formatron.grammar_generators.json_generator import JsonGenerator
from formatron.integrations.exllamav2 import create_formatter_filter
from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter

from utils import Order
from benchmarks.utils import load_address, load_linkedlist, load_orders, force_gc, address_lfe, linked_list_lfe, \
order_lfe
from test_grammar_gen import LinkedList
from utils import Address, BenchResult, Context, log
from utils import Order


def create_exllamav2_6bpw_llama3_8b():
Expand Down Expand Up @@ -42,27 +42,38 @@ def create_exllamav2_4bpw_llama2_7b():
)
return generator

def get_address_filter():
def f_get_address_filter():
    """Build the Formatron filter that constrains output to the ``Address`` schema.

    Reads the model and tokenizer from the module-level ``generator``.

    Returns:
        The object produced by ``create_formatter_filter`` for a formatter
        holding one line: the JSON schema of ``Address``, captured as ``json``.
    """
    builder = FormatterBuilder()
    address_json = builder.schema(Address, JsonGenerator(), capture_name='json')
    builder.append_line(f"{address_json}")
    return create_formatter_filter(generator.model, generator.tokenizer, builder)

def get_linkedlist_filter():
def lfe_get_address_filter():
    """Build the lm-format-enforcer filter for the address benchmark.

    Wraps ``address_lfe`` (imported from ``benchmarks.utils``; presumably the
    lm-format-enforcer parser for the address schema -- confirm there) with
    the tokenizer of the module-level ``generator``.

    Returns:
        A new ``ExLlamaV2TokenEnforcerFilter`` instance.
    """
    return ExLlamaV2TokenEnforcerFilter(address_lfe, generator.tokenizer)

def f_get_linkedlist_filter():
    """Build the Formatron filter that constrains output to the ``LinkedList`` schema.

    Reads the model and tokenizer from the module-level ``generator``.

    Returns:
        The object produced by ``create_formatter_filter`` for a formatter
        holding one line: the JSON schema of ``LinkedList``, captured as ``json``.
    """
    builder = FormatterBuilder()
    linkedlist_json = builder.schema(LinkedList, JsonGenerator(), capture_name='json')
    builder.append_line(f"{linkedlist_json}")
    return create_formatter_filter(generator.model, generator.tokenizer, builder)

def get_order_filter():
def lfe_get_linkedlist_filter():
    """Build the lm-format-enforcer filter for the linked-list benchmark.

    Wraps ``linked_list_lfe`` (imported from ``benchmarks.utils``; presumably
    the lm-format-enforcer parser for the linked-list schema -- confirm there)
    with the tokenizer of the module-level ``generator``.

    Returns:
        A new ``ExLlamaV2TokenEnforcerFilter`` instance.
    """
    return ExLlamaV2TokenEnforcerFilter(linked_list_lfe, generator.tokenizer)

def f_get_order_filter():
    """Build the Formatron filter that constrains output to the ``Order`` schema.

    Reads the model and tokenizer from the module-level ``generator``.

    Returns:
        The object produced by ``create_formatter_filter`` for a formatter
        holding one line: the JSON schema of ``Order``, captured as ``json``.
    """
    builder = FormatterBuilder()
    order_json = builder.schema(Order, JsonGenerator(), capture_name='json')
    builder.append_line(f"{order_json}")
    return create_formatter_filter(generator.model, generator.tokenizer, builder)

def lfe_get_order_filter():
    """Build the lm-format-enforcer filter for the orders benchmark.

    Wraps ``order_lfe`` (imported from ``benchmarks.utils``; presumably the
    lm-format-enforcer parser for the order schema -- confirm there) with the
    tokenizer of the module-level ``generator``.

    Returns:
        A new ``ExLlamaV2TokenEnforcerFilter`` instance.
    """
    return ExLlamaV2TokenEnforcerFilter(order_lfe, generator.tokenizer)

def execute():
prompt = f"""{system_prompt}{inputs[context.index]}<|eot_id|><|start_header_id|>assistant<|end_header_id|> Sure! Here is the json: """
prompt = f"""{system_prompt}{inputs[context.index]}{tail} Sure! Here is the json: """
output = generator.generate(
prompt=prompt,
max_new_tokens=max_new_tokens,
Expand All @@ -74,7 +85,12 @@ def execute():
)
context.index += 1
if context.filters:
context.tokens += len(context.filters[0]._formatter._token_ids)
if isinstance(context.filters[0],formatron.integrations.exllamav2.FormatterFilter):
context.tokens += len(context.filters[0]._formatter._token_ids)
elif isinstance(context.filters[0], ExLlamaV2TokenEnforcerFilter):
context.tokens += len(context.filters[0].token_sequence)
else:
raise ValueError(f"Unsupported filter {type(context.filters[0])}")
else:
assert not output.endswith(generator.tokenizer.eos_token), "Something is wrong"
context.tokens += max_new_tokens
Expand All @@ -85,6 +101,7 @@ def warm_up(f):
context.tokens = 0

def bench(result:BenchResult, context:Context,func, bench_name:str, f):
global settings
context.index = 0
context.tokens = 0
result.s1 = (timeit(func, setup=lambda: warm_up(func), number=len(inputs)))
Expand All @@ -95,50 +112,71 @@ def bench(result:BenchResult, context:Context,func, bench_name:str, f):
settings.disallow_tokens(generator.tokenizer, [generator.tokenizer.eos_token_id])
result.s2 = (timeit(func, number=len(inputs)))
result.t2 = context.tokens
settings = ExLlamaV2Sampler.Settings()
log(bench_name, result, f)


if __name__ == '__main__':
system_prompt = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful AI assistant for information extraction<|eot_id|><|start_header_id|>user<|end_header_id|>

Extract information into json format: """
data = BenchResult(0, 0, 0, 0)
context = Context(0, 0)
with open("exllamav2_json.txt", "w") as f:
generator = create_exllamav2_6bpw_llama3_8b()
system_prompt = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful AI assistant for information extraction<|eot_id|><|start_header_id|>user<|end_header_id|>
Extract information into json format: """
tail = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
settings = ExLlamaV2Sampler.Settings()
inputs = json.load(open("address.json"))["sentences"]
context.filters = [get_address_filter()]
# --------------------------------------------------------------------------------------------------------------
inputs = load_address()
context.filters = [f_get_address_filter()]
max_new_tokens = 100
bench(data, context, execute, "llama3_8b_6pw_exl2_address_json_exllamav2", f)
settings = ExLlamaV2Sampler.Settings()
context.filters = [get_linkedlist_filter()]
inputs = json.load(open("linkedlist.json"))["sentences"]
max_new_tokens = 32
bench(data, context, execute, "llama3_8b_6pw_exl2_linkedlist_json_exllamav2", f)
settings = ExLlamaV2Sampler.Settings()
context.filters = [get_order_filter()]
inputs = json.load(open("orders.json"))["orders"]
max_new_tokens = 160
bench(data, context, execute, "llama3_8b_6pw_exl2_orders_json_exllamav2", f)
bench(data, context, execute, "formatron_llama3_8b_6pw_exl2_address_json_exllamav2", f)
context.filters = [lfe_get_address_filter()]
bench(data, context, execute, "lm_format_enforcer_llama3_8b_6pw_exl2_address_json_exllamav2", f)
# --------------------------------------------------------------------------------------------------------------
context.filters = [f_get_linkedlist_filter()]
inputs = load_linkedlist()
max_new_tokens = 50
bench(data, context, execute, "formatron_llama3_8b_6pw_exl2_linkedlist_json_exllamav2", f)
context.filters = [lfe_get_linkedlist_filter()]
bench(data, context, execute, "lm_format_enforcer_llama3_8b_6pw_exl2_linkedlist_json_exllamav2", f)
# --------------------------------------------------------------------------------------------------------------
context.filters = [f_get_order_filter()]
inputs = load_orders()
max_new_tokens = 200
bench(data, context, execute, "formatron_llama3_8b_6pw_exl2_orders_json_exllamav2", f)
context.filters = [lfe_get_order_filter()]
bench(data, context, execute, "lm_format_enforcer_llama3_8b_6pw_exl2_orders_json_exllamav2", f)
# --------------------------------------------------------------------------------------------------------------
del generator
gc.collect()
torch.cuda.empty_cache()
force_gc()
generator = create_exllamav2_4bpw_llama2_7b()
settings = ExLlamaV2Sampler.Settings()
inputs = json.load(open("address.json"))["sentences"]
context.filters = [get_address_filter()]
max_new_tokens = 120
bench(data, context, execute, "llama2_7b_4pw_exl2_address_json_exllamav2", f)
settings = ExLlamaV2Sampler.Settings()
context.filters = [get_linkedlist_filter()]
inputs = json.load(open("linkedlist.json"))["sentences"]
max_new_tokens = 15
bench(data, context, execute, "llama2_7b_4pw_exl2_linkedlist_json_exllamav2", f)
settings = ExLlamaV2Sampler.Settings()
context.filters = [get_order_filter()]
inputs = json.load(open("orders.json"))["orders"]
max_new_tokens = 160
bench(data, context, execute, "llama2_7b_4pw_exl2_orders_json_exllamav2", f)
system_prompt = """[INST]
You are a helpful AI assistant for information extraction.
Extract information into json format: """
tail = "[/INST]"
# --------------------------------------------------------------------------------------------------------------
inputs = load_address()
context.filters = [f_get_address_filter()]
max_new_tokens = 100
bench(data, context, execute, "formatron_llama2_7b_4pw_exl2_address_json_exllamav2", f)
context.filters = [lfe_get_address_filter()]
bench(data, context, execute, "lm_format_enforcer_llama2_7b_4pw_exl2_address_json_exllamav2", f)
# --------------------------------------------------------------------------------------------------------------
context.filters = [f_get_linkedlist_filter()]
inputs = load_linkedlist()
max_new_tokens = 50
bench(data, context, execute, "formatron_llama2_7b_4pw_exl2_linkedlist_json_exllamav2", f)
context.filters = [lfe_get_linkedlist_filter()]
bench(data, context, execute, "lm_format_enforcer_llama2_7b_4pw_exl2_linkedlist_json_exllamav2", f)
# --------------------------------------------------------------------------------------------------------------
context.filters = [f_get_order_filter()]
inputs = load_orders()
max_new_tokens = 200
bench(data, context, execute, "formatron_llama2_7b_4pw_exl2_orders_json_exllamav2", f)
context.filters = [lfe_get_order_filter()]
bench(data, context, execute, "lm_format_enforcer_llama2_7b_4pw_exl2_orders_json_exllamav2", f)
36 changes: 24 additions & 12 deletions benchmarks/exllamav2_json.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,24 @@
llama3_8b_6pw_exl2_address_json_exllamav2 generated 1937 tokens with 81.76457267212113 tps (with warm up)
llama3_8b_6pw_exl2_address_json_exllamav2 unconstrained generated 2000 tokens with 91.93855585432294 tps
llama3_8b_6pw_exl2_linkedlist_json_exllamav2 generated 567 tokens with 73.72004132348941 tps (with warm up)
llama3_8b_6pw_exl2_linkedlist_json_exllamav2 unconstrained generated 640 tokens with 92.92655429712437 tps
llama3_8b_6pw_exl2_orders_json_exllamav2 generated 2976 tokens with 79.10910035605352 tps (with warm up)
llama3_8b_6pw_exl2_orders_json_exllamav2 unconstrained generated 3200 tokens with 93.46945772542723 tps
llama2_7b_4pw_exl2_address_json_exllamav2 generated 2400 tokens with 123.7077165970634 tps (with warm up)
llama2_7b_4pw_exl2_address_json_exllamav2 unconstrained generated 2400 tokens with 133.37570270534903 tps
llama2_7b_4pw_exl2_linkedlist_json_exllamav2 generated 250 tokens with 80.04987619935734 tps (with warm up)
llama2_7b_4pw_exl2_linkedlist_json_exllamav2 unconstrained generated 300 tokens with 132.19982863147897 tps
llama2_7b_4pw_exl2_orders_json_exllamav2 generated 3136 tokens with 117.27953013576354 tps (with warm up)
llama2_7b_4pw_exl2_orders_json_exllamav2 unconstrained generated 3200 tokens with 129.65265959777014 tps
formatron_llama3_8b_6pw_exl2_address_json_exllamav2 generated 1937 tokens with 81.90114209213694 tps (with warm up)
formatron_llama3_8b_6pw_exl2_address_json_exllamav2 unconstrained generated 2000 tokens with 92.93084312000893 tps
lm_format_enforcer_llama3_8b_6pw_exl2_address_json_exllamav2 generated 2000 tokens with 44.49453610169303 tps (with warm up)
lm_format_enforcer_llama3_8b_6pw_exl2_address_json_exllamav2 unconstrained generated 2000 tokens with 92.60734707373982 tps
formatron_llama3_8b_6pw_exl2_linkedlist_json_exllamav2 generated 712 tokens with 62.301888356328774 tps (with warm up)
formatron_llama3_8b_6pw_exl2_linkedlist_json_exllamav2 unconstrained generated 1000 tokens with 94.1753542487434 tps
lm_format_enforcer_llama3_8b_6pw_exl2_linkedlist_json_exllamav2 generated 1000 tokens with 50.55230703666647 tps (with warm up)
lm_format_enforcer_llama3_8b_6pw_exl2_linkedlist_json_exllamav2 unconstrained generated 1000 tokens with 93.80809481125524 tps
formatron_llama3_8b_6pw_exl2_orders_json_exllamav2 generated 3282 tokens with 71.34696208336447 tps (with warm up)
formatron_llama3_8b_6pw_exl2_orders_json_exllamav2 unconstrained generated 4000 tokens with 94.47558818852826 tps
lm_format_enforcer_llama3_8b_6pw_exl2_orders_json_exllamav2 generated 4000 tokens with 44.2748717455552 tps (with warm up)
lm_format_enforcer_llama3_8b_6pw_exl2_orders_json_exllamav2 unconstrained generated 4000 tokens with 94.47549746230774 tps
formatron_llama2_7b_4pw_exl2_address_json_exllamav2 generated 1895 tokens with 116.77613285626795 tps (with warm up)
formatron_llama2_7b_4pw_exl2_address_json_exllamav2 unconstrained generated 2000 tokens with 135.14145208580499 tps
lm_format_enforcer_llama2_7b_4pw_exl2_address_json_exllamav2 generated 2000 tokens with 122.83497659799589 tps (with warm up)
lm_format_enforcer_llama2_7b_4pw_exl2_address_json_exllamav2 unconstrained generated 2000 tokens with 135.06497241415678 tps
formatron_llama2_7b_4pw_exl2_linkedlist_json_exllamav2 generated 749 tokens with 94.48129844732799 tps (with warm up)
formatron_llama2_7b_4pw_exl2_linkedlist_json_exllamav2 unconstrained generated 1000 tokens with 135.2508753195326 tps
lm_format_enforcer_llama2_7b_4pw_exl2_linkedlist_json_exllamav2 generated 1000 tokens with 131.53638682928263 tps (with warm up)
lm_format_enforcer_llama2_7b_4pw_exl2_linkedlist_json_exllamav2 unconstrained generated 1000 tokens with 135.1702746403635 tps
formatron_llama2_7b_4pw_exl2_orders_json_exllamav2 generated 3672 tokens with 113.24874134695554 tps (with warm up)
formatron_llama2_7b_4pw_exl2_orders_json_exllamav2 unconstrained generated 4000 tokens with 131.19728962397383 tps
lm_format_enforcer_llama2_7b_4pw_exl2_orders_json_exllamav2 generated 4000 tokens with 121.46458095557651 tps (with warm up)
lm_format_enforcer_llama2_7b_4pw_exl2_orders_json_exllamav2 unconstrained generated 4000 tokens with 131.21070982754418 tps
36 changes: 17 additions & 19 deletions benchmarks/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,27 +29,25 @@ Default vllm setting are used.
| Llama2-7B(fp16) | linkedlist_json | 46.51 | 47.54 | 0.46 |
| Llama2-7B(fp16) | order_json | 45.71 | 46.68 | 0.46 |
## Exllamav2
Default exllamav2 settings are used. The inferior performance of the exllamav2 integration
can be attributed to the fact that `Exllamav2Filter` requires the implementation to return
a set of allowed tokens, and constructing a large set is very slow in Python.
Default exllamav2 settings are used.

| model | schema | constrained(with warm-up) / tps | unconstrained / tps | overhead per token / ms |
|------------------------|-----------------|---------------------------------|---------------------|-------------------------|
| Llama3-8B(6.0bpw-exl2) | address_json | 81.76 | 91.94 | 1.36 |
| Llama3-8B(6.0bpw-exl2) | linkedlist_json | 73.73 | 92.93 | 2.82 |
| Llama3-8B(6.0bpw-exl2) | order_json | 79.11 | 93.47 | 1.96 |
| Llama2-7B(4.0bpw-exl2) | address_json | 123.71 | 133.38 | 0.55 |
| Llama2-7B(4.0bpw-exl2) | linkedlist_json | 80.05 | 132.20 | 4.90 |
| Llama2-7B(4.0bpw-exl2) | order_json | 117.28 | 129.65 | 0.82 |
| model | schema | Formatron overhead per token(with warm-up) / ms | lm-format-enforcer overhead per token(with warm-up) / ms |
|------------------------|-----------------|-------------------------------------------------|----------------------------------------------------------|
| Llama3-8B(6.0bpw-exl2) | address_json | 1.4 | 11.7 |
| Llama3-8B(6.0bpw-exl2) | linkedlist_json | 5.4 | 9.1 |
| Llama3-8B(6.0bpw-exl2) | order_json | 3.4 | 12.1 |
| Llama2-7B(4.0bpw-exl2) | address_json | 1.2 | 0.73 |
| Llama2-7B(4.0bpw-exl2) | linkedlist_json | 3.2 | 0.20 |
| Llama2-7B(4.0bpw-exl2) | order_json | 1.2 | 0.60 |

## Transformers
Default transformers settings are used, with flash attention v2 enabled.

| model | schema | constrained(with warm-up) / tps | unconstrained / tps | overhead per token / ms |
|-----------------|-----------------|---------------------------------|---------------------|-------------------------|
| Llama3-8B(bf16) | address_json | 37.39 | 38.71 | 0.91 |
| Llama3-8B(bf16) | linkedlist_json | 37.25 | 38.65 | 0.98 |
| Llama3-8B(bf16) | order_json | 36.73 | 38.11 | 0.99 |
| Llama2-7B(fp16) | address_json | 41.30 | 42.14 | 0.48 |
| Llama2-7B(fp16) | linkedlist_json | 40.75 | 41.91 | 0.68 |
| Llama2-7B(fp16) | order_json | 39.70 | 40.41 | 0.44 |
| model | schema | Formatron overhead per token(with warm-up) / ms | lm-format-enforcer overhead per token(with warm-up) / ms |
|-----------------|-----------------|-------------------------------------------------|----------------------------------------------------------|
| Llama3-8B(bf16) | address_json | 0.65 | 9.3 |
| Llama3-8B(bf16) | linkedlist_json | 0.70 | 3.5 |
| Llama3-8B(bf16) | order_json | 0.69 | 6.1 |
| Llama2-7B(fp16) | address_json | 0.41 | 1.4 |
| Llama2-7B(fp16) | linkedlist_json | 0.54 | 0.58 |
| Llama2-7B(fp16) | order_json | 0.44 | 0.96 |
Loading

0 comments on commit 40f4632

Please sign in to comment.