diff --git a/benchmarks/exllamav2_json.py b/benchmarks/exllamav2_json.py index b3f0aff8..f48fa300 100644 --- a/benchmarks/exllamav2_json.py +++ b/benchmarks/exllamav2_json.py @@ -1,17 +1,17 @@ -import gc -import json from timeit import timeit - -import torch +import formatron.integrations.exllamav2 from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2Sampler from formatron.formatter import FormatterBuilder from formatron.grammar_generators.json_generator import JsonGenerator from formatron.integrations.exllamav2 import create_formatter_filter +from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter -from utils import Order +from benchmarks.utils import load_address, load_linkedlist, load_orders, force_gc, address_lfe, linked_list_lfe, \ + order_lfe from test_grammar_gen import LinkedList from utils import Address, BenchResult, Context, log +from utils import Order def create_exllamav2_6bpw_llama3_8b(): @@ -42,27 +42,38 @@ def create_exllamav2_4bpw_llama2_7b(): ) return generator -def get_address_filter(): +def f_get_address_filter(): f = FormatterBuilder() f.append_line(f"{f.schema(Address, JsonGenerator(), capture_name='json')}") exllama_filter = create_formatter_filter(generator.model, generator.tokenizer, f) return exllama_filter -def get_linkedlist_filter(): +def lfe_get_address_filter(): + exllama_filter = ExLlamaV2TokenEnforcerFilter(address_lfe, generator.tokenizer) + return exllama_filter + +def f_get_linkedlist_filter(): f = FormatterBuilder() f.append_line(f"{f.schema(LinkedList, JsonGenerator(), capture_name='json')}") exllama_filter = create_formatter_filter(generator.model, generator.tokenizer, f) return exllama_filter -def get_order_filter(): +def lfe_get_linkedlist_filter(): + exllama_filter = ExLlamaV2TokenEnforcerFilter(linked_list_lfe, generator.tokenizer) + return exllama_filter + +def f_get_order_filter(): f = FormatterBuilder() f.append_line(f"{f.schema(Order, JsonGenerator(), capture_name='json')}") exllama_filter = create_formatter_filter(generator.model, generator.tokenizer, f) return exllama_filter +def lfe_get_order_filter(): + exllama_filter = ExLlamaV2TokenEnforcerFilter(order_lfe, generator.tokenizer) + return exllama_filter def execute(): - prompt = f"""{system_prompt}{inputs[context.index]}<|eot_id|><|start_header_id|>assistant<|end_header_id|> Sure! Here is the json: """ + prompt = f"""{system_prompt}{inputs[context.index]}{tail} Sure! Here is the json: """ output = generator.generate( prompt=prompt, max_new_tokens=max_new_tokens, @@ -74,7 +85,12 @@ def execute(): ) context.index += 1 if context.filters: - context.tokens += len(context.filters[0]._formatter._token_ids) + if isinstance(context.filters[0],formatron.integrations.exllamav2.FormatterFilter): + context.tokens += len(context.filters[0]._formatter._token_ids) + elif isinstance(context.filters[0], ExLlamaV2TokenEnforcerFilter): + context.tokens += len(context.filters[0].token_sequence) + else: + raise ValueError(f"Unsupported filter {type(context.filters[0])}") else: assert not output.endswith(generator.tokenizer.eos_token), "Something is wrong" context.tokens += max_new_tokens @@ -85,6 +101,7 @@ def warm_up(f): context.tokens = 0 def bench(result:BenchResult, context:Context,func, bench_name:str, f): + global settings context.index = 0 context.tokens = 0 result.s1 = (timeit(func, setup=lambda: warm_up(func), number=len(inputs))) @@ -95,50 +112,71 @@ def bench(result:BenchResult, context:Context,func, bench_name:str, f): settings.disallow_tokens(generator.tokenizer, [generator.tokenizer.eos_token_id]) result.s2 = (timeit(func, number=len(inputs))) result.t2 = context.tokens + settings = ExLlamaV2Sampler.Settings() log(bench_name, result, f) if __name__ == '__main__': - system_prompt = """<|begin_of_text|><|start_header_id|>system<|end_header_id|> - -You are a helpful AI assistant for information extraction<|eot_id|><|start_header_id|>user<|end_header_id|> -Extract information into json format: """ data = BenchResult(0, 0, 0, 0) context = Context(0, 0) with open("exllamav2_json.txt", "w") as f: generator = create_exllamav2_6bpw_llama3_8b() + system_prompt = """<|begin_of_text|><|start_header_id|>system<|end_header_id|> + + You are a helpful AI assistant for information extraction<|eot_id|><|start_header_id|>user<|end_header_id|> + + Extract information into json format: """ + tail = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>" settings = ExLlamaV2Sampler.Settings() - inputs = json.load(open("address.json"))["sentences"] - context.filters = [get_address_filter()] + # -------------------------------------------------------------------------------------------------------------- + inputs = load_address() + context.filters = [f_get_address_filter()] max_new_tokens = 100 - bench(data, context, execute, "llama3_8b_6pw_exl2_address_json_exllamav2", f) - settings = ExLlamaV2Sampler.Settings() - context.filters = [get_linkedlist_filter()] - inputs = json.load(open("linkedlist.json"))["sentences"] - max_new_tokens = 32 - bench(data, context, execute, "llama3_8b_6pw_exl2_linkedlist_json_exllamav2", f) - settings = ExLlamaV2Sampler.Settings() - context.filters = [get_order_filter()] - inputs = json.load(open("orders.json"))["orders"] - max_new_tokens = 160 - bench(data, context, execute, "llama3_8b_6pw_exl2_orders_json_exllamav2", f) + bench(data, context, execute, "formatron_llama3_8b_6pw_exl2_address_json_exllamav2", f) + context.filters = [lfe_get_address_filter()] + bench(data, context, execute, "lm_format_enforcer_llama3_8b_6pw_exl2_address_json_exllamav2", f) + # -------------------------------------------------------------------------------------------------------------- + context.filters = [f_get_linkedlist_filter()] + inputs = load_linkedlist() + max_new_tokens = 50 + bench(data, context, execute, "formatron_llama3_8b_6pw_exl2_linkedlist_json_exllamav2", f) + context.filters = [lfe_get_linkedlist_filter()] + bench(data, context, execute, "lm_format_enforcer_llama3_8b_6pw_exl2_linkedlist_json_exllamav2", f) + # -------------------------------------------------------------------------------------------------------------- + context.filters = [f_get_order_filter()] + inputs = load_orders() + max_new_tokens = 200 + bench(data, context, execute, "formatron_llama3_8b_6pw_exl2_orders_json_exllamav2", f) + context.filters = [lfe_get_order_filter()] + bench(data, context, execute, "lm_format_enforcer_llama3_8b_6pw_exl2_orders_json_exllamav2", f) + # -------------------------------------------------------------------------------------------------------------- del generator - gc.collect() - torch.cuda.empty_cache() + force_gc() generator = create_exllamav2_4bpw_llama2_7b() - settings = ExLlamaV2Sampler.Settings() - inputs = json.load(open("address.json"))["sentences"] - context.filters = [get_address_filter()] - max_new_tokens = 120 - bench(data, context, execute, "llama2_7b_4pw_exl2_address_json_exllamav2", f) - settings = ExLlamaV2Sampler.Settings() - context.filters = [get_linkedlist_filter()] - inputs = json.load(open("linkedlist.json"))["sentences"] - max_new_tokens = 15 - bench(data, context, execute, "llama2_7b_4pw_exl2_linkedlist_json_exllamav2", f) - settings = ExLlamaV2Sampler.Settings() - context.filters = [get_order_filter()] - inputs = json.load(open("orders.json"))["orders"] - max_new_tokens = 160 - bench(data, context, execute, "llama2_7b_4pw_exl2_orders_json_exllamav2", f) \ No newline at end of file + system_prompt = """[INST] + You are a helpful AI assistant for information extraction. + + Extract information into json format: """ + tail = "[/INST]" + # -------------------------------------------------------------------------------------------------------------- + inputs = load_address() + context.filters = [f_get_address_filter()] + max_new_tokens = 100 + bench(data, context, execute, "formatron_llama2_7b_4pw_exl2_address_json_exllamav2", f) + context.filters = [lfe_get_address_filter()] + bench(data, context, execute, "lm_format_enforcer_llama2_7b_4pw_exl2_address_json_exllamav2", f) + # -------------------------------------------------------------------------------------------------------------- + context.filters = [f_get_linkedlist_filter()] + inputs = load_linkedlist() + max_new_tokens = 50 + bench(data, context, execute, "formatron_llama2_7b_4pw_exl2_linkedlist_json_exllamav2", f) + context.filters = [lfe_get_linkedlist_filter()] + bench(data, context, execute, "lm_format_enforcer_llama2_7b_4pw_exl2_linkedlist_json_exllamav2", f) + # -------------------------------------------------------------------------------------------------------------- + context.filters = [f_get_order_filter()] + inputs = load_orders() + max_new_tokens = 200 + bench(data, context, execute, "formatron_llama2_7b_4pw_exl2_orders_json_exllamav2", f) + context.filters = [lfe_get_order_filter()] + bench(data, context, execute, "lm_format_enforcer_llama2_7b_4pw_exl2_orders_json_exllamav2", f) \ No newline at end of file diff --git a/benchmarks/exllamav2_json.txt b/benchmarks/exllamav2_json.txt index 9f2a7c4a..17bf3bb5 100644 --- a/benchmarks/exllamav2_json.txt +++ b/benchmarks/exllamav2_json.txt @@ -1,12 +1,24 @@ -llama3_8b_6pw_exl2_address_json_exllamav2 generated 1937 tokens with 81.76457267212113 tps (with warm up) -llama3_8b_6pw_exl2_address_json_exllamav2 unconstrained generated 2000 tokens with 91.93855585432294 tps -llama3_8b_6pw_exl2_linkedlist_json_exllamav2 generated 567 tokens with 73.72004132348941 tps (with warm up) -llama3_8b_6pw_exl2_linkedlist_json_exllamav2 unconstrained generated 640 tokens with 92.92655429712437 tps -llama3_8b_6pw_exl2_orders_json_exllamav2 generated 2976 tokens with 79.10910035605352 tps (with warm up) -llama3_8b_6pw_exl2_orders_json_exllamav2 unconstrained generated 3200 tokens with 93.46945772542723 tps -llama2_7b_4pw_exl2_address_json_exllamav2 generated 2400 tokens with 123.7077165970634 tps (with warm up) -llama2_7b_4pw_exl2_address_json_exllamav2 unconstrained generated 2400 tokens with 133.37570270534903 tps -llama2_7b_4pw_exl2_linkedlist_json_exllamav2 generated 250 tokens with 80.04987619935734 tps (with warm up) -llama2_7b_4pw_exl2_linkedlist_json_exllamav2 unconstrained generated 300 tokens with 132.19982863147897 tps -llama2_7b_4pw_exl2_orders_json_exllamav2 generated 3136 tokens with 117.27953013576354 tps (with warm up) -llama2_7b_4pw_exl2_orders_json_exllamav2 unconstrained generated 3200 tokens with 129.65265959777014 tps +formatron_llama3_8b_6pw_exl2_address_json_exllamav2 generated 1937 tokens with 81.90114209213694 tps (with warm up) +formatron_llama3_8b_6pw_exl2_address_json_exllamav2 unconstrained generated 2000 tokens with 92.93084312000893 tps +lm_format_enforcer_llama3_8b_6pw_exl2_address_json_exllamav2 generated 2000 tokens with 44.49453610169303 tps (with warm up) +lm_format_enforcer_llama3_8b_6pw_exl2_address_json_exllamav2 unconstrained generated 2000 tokens with 92.60734707373982 tps +formatron_llama3_8b_6pw_exl2_linkedlist_json_exllamav2 generated 712 tokens with 62.301888356328774 tps (with warm up) +formatron_llama3_8b_6pw_exl2_linkedlist_json_exllamav2 unconstrained generated 1000 tokens with 94.1753542487434 tps +lm_format_enforcer_llama3_8b_6pw_exl2_linkedlist_json_exllamav2 generated 1000 tokens with 50.55230703666647 tps (with warm up) +lm_format_enforcer_llama3_8b_6pw_exl2_linkedlist_json_exllamav2 unconstrained generated 1000 tokens with 93.80809481125524 tps +formatron_llama3_8b_6pw_exl2_orders_json_exllamav2 generated 3282 tokens with 71.34696208336447 tps (with warm up) +formatron_llama3_8b_6pw_exl2_orders_json_exllamav2 unconstrained generated 4000 tokens with 94.47558818852826 tps +lm_format_enforcer_llama3_8b_6pw_exl2_orders_json_exllamav2 generated 4000 tokens with 44.2748717455552 tps (with warm up) +lm_format_enforcer_llama3_8b_6pw_exl2_orders_json_exllamav2 unconstrained generated 4000 tokens with 94.47549746230774 tps +formatron_llama2_7b_4pw_exl2_address_json_exllamav2 generated 1895 tokens with 116.77613285626795 tps (with warm up) +formatron_llama2_7b_4pw_exl2_address_json_exllamav2 unconstrained generated 2000 tokens with 135.14145208580499 tps +lm_format_enforcer_llama2_7b_4pw_exl2_address_json_exllamav2 generated 2000 tokens with 122.83497659799589 tps (with warm up) +lm_format_enforcer_llama2_7b_4pw_exl2_address_json_exllamav2 unconstrained generated 2000 tokens with 135.06497241415678 tps +formatron_llama2_7b_4pw_exl2_linkedlist_json_exllamav2 generated 749 tokens with 94.48129844732799 tps (with warm up) +formatron_llama2_7b_4pw_exl2_linkedlist_json_exllamav2 unconstrained generated 1000 tokens with 135.2508753195326 tps +lm_format_enforcer_llama2_7b_4pw_exl2_linkedlist_json_exllamav2 generated 1000 tokens with 131.53638682928263 tps (with warm up) +lm_format_enforcer_llama2_7b_4pw_exl2_linkedlist_json_exllamav2 unconstrained generated 1000 tokens with 135.1702746403635 tps +formatron_llama2_7b_4pw_exl2_orders_json_exllamav2 generated 3672 tokens with 113.24874134695554 tps (with warm up) +formatron_llama2_7b_4pw_exl2_orders_json_exllamav2 unconstrained generated 4000 tokens with 131.19728962397383 tps +lm_format_enforcer_llama2_7b_4pw_exl2_orders_json_exllamav2 generated 4000 tokens with 121.46458095557651 tps (with warm up) +lm_format_enforcer_llama2_7b_4pw_exl2_orders_json_exllamav2 unconstrained generated 4000 tokens with 131.21070982754418 tps diff --git a/benchmarks/readme.md b/benchmarks/readme.md index 72c52568..74d389b5 100644 --- a/benchmarks/readme.md +++ b/benchmarks/readme.md @@ -29,27 +29,25 @@ Default vllm setting are used. | Llama2-7B(fp16) | linkedlist_json | 46.51 | 47.54 | 0.46 | | Llama2-7B(fp16) | order_json | 45.71 | 46.68 | 0.46 | ## Exllamav2 -Default exllamav2 setting are used. The inferior performance of exllamav2 integration -can be attributed to the fact that `Exllamav2Filter` requires the implementation to return -a set of allowed tokens, and constructing a large set is very slow in Python. +Default exllamav2 setting are used. -| model | schema | constrained(with warm-up) / tps | unconstrained / tps | overhead per token / ms | -|------------------------|-----------------|---------------------------------|---------------------|-------------------------| -| Llama3-8B(6.0bpw-exl2) | address_json | 81.76 | 91.94 | 1.36 | -| Llama3-8B(6.0bpw-exl2) | linkedlist_json | 73.73 | 92.93 | 2.82 | -| Llama3-8B(6.0bpw-exl2) | order_json | 79.11 | 93.47 | 1.96 | -| Llama2-7B(4.0bpw-exl2) | address_json | 123.71 | 133.38 | 0.55 | -| Llama2-7B(4.0bpw-exl2) | linkedlist_json | 80.05 | 132.20 | 4.90 | -| Llama2-7B(4.0bpw-exl2) | order_json | 117.28 | 129.65 | 0.82 | +| model | schema | Formatron overhead per token(with warm-up) / ms | lm format enforcer overhead(with warm-up) per token / ms | +|------------------------|-----------------|-------------------------------------------------|----------------------------------------------------------| +| Llama3-8B(6.0bpw-exl2) | address_json | 1.4 | 11.7 | +| Llama3-8B(6.0bpw-exl2) | linkedlist_json | 5.4 | 9.1 | +| Llama3-8B(6.0bpw-exl2) | order_json | 3.4 | 12.1 | +| Llama2-7B(4.0bpw-exl2) | address_json | 1.2 | 0.73 | +| Llama2-7B(4.0bpw-exl2) | linkedlist_json | 3.2 | 0.20 | +| Llama2-7B(4.0bpw-exl2) | order_json | 1.2 | 0.60 | ## Transformers Default transformers setting with flash attention v2 enabled. -| model | schema | constrained(with warm-up) / tps | unconstrained / tps | overhead per token / ms | -|-----------------|-----------------|---------------------------------|---------------------|-------------------------| -| Llama3-8B(bf16) | address_json | 37.39 | 38.71 | 0.91 | -| Llama3-8B(bf16) | linkedlist_json | 37.25 | 38.65 | 0.98 | -| Llama3-8B(bf16) | order_json | 36.73 | 38.11 | 0.99 | -| Llama2-7B(fp16) | address_json | 41.30 | 42.14 | 0.48 | -| Llama2-7B(fp16) | linkedlist_json | 40.75 | 41.91 | 0.68 | -| Llama2-7B(fp16) | order_json | 39.70 | 40.41 | 0.44 | \ No newline at end of file +| model | schema | Formatron overhead per token(with warm-up) / ms | lm format enforcer overhead(with warm-up) per token / ms | +|-----------------|-----------------|-------------------------------------------------|----------------------------------------------------------| +| Llama3-8B(bf16) | address_json | 0.65 | 9.3 | +| Llama3-8B(bf16) | linkedlist_json | 0.70 | 3.5 | +| Llama3-8B(bf16) | order_json | 0.69 | 6.1 | +| Llama2-7B(fp16) | address_json | 0.41 | 1.4 | +| Llama2-7B(fp16) | linkedlist_json | 0.54 | 0.58 | +| Llama2-7B(fp16) | order_json | 0.44 | 0.96 | diff --git a/benchmarks/transformers_json.py b/benchmarks/transformers_json.py index 5c1f6843..1a9988fd 100644 --- a/benchmarks/transformers_json.py +++ b/benchmarks/transformers_json.py @@ -5,13 +5,14 @@ import torch from transformers import AutoModelForCausalLM, AutoTokenizer +from benchmarks.utils import load_address, load_linkedlist, load_orders, address_lfe, linked_list_lfe, order_lfe from utils import Order from grammar_generators.json_generator import JsonGenerator from integrations.transformers import create_formatter_logits_processor_list, FormattersLogitsProcessor from test_grammar_gen import LinkedList from utils import BenchResult, Context, Address, log from formatron.formatter import FormatterBuilder - +from lmformatenforcer.integrations.transformers import build_transformers_prefix_allowed_tokens_fn def get_llama3_8b_tokenizer_and_model(): model = AutoModelForCausalLM.from_pretrained("NurtureAI/Meta-Llama-3-8B-Instruct-32k", @@ -36,16 +37,25 @@ def get_address_schema(): f.append_line(f"{f.schema(Address, JsonGenerator(), capture_name='json')}") return create_formatter_logits_processor_list(tokenizer, f) +def lfe_address_prefix(): + return build_transformers_prefix_allowed_tokens_fn(tokenizer, address_lfe) + def get_linkedlist_schema(): f = FormatterBuilder() f.append_line(f"{f.schema(LinkedList, JsonGenerator(), capture_name='json')}") return create_formatter_logits_processor_list(tokenizer, f) +def lfe_linkedlist_prefix(): + return build_transformers_prefix_allowed_tokens_fn(tokenizer, linked_list_lfe) + def get_order_schema(): f = FormatterBuilder() f.append_line(f"{f.schema(Order, JsonGenerator(), capture_name='json')}") return create_formatter_logits_processor_list(tokenizer, f) +def lfe_order_prefix(): + return build_transformers_prefix_allowed_tokens_fn(tokenizer, order_lfe) + def execute(): prompts = [ f"{system_prompt}{inputs[context.index]}{tail}", @@ -53,8 +63,12 @@ def execute(): prompts = tokenizer(prompts, return_tensors='pt').to(model.device) input_len = prompts.input_ids.shape[-1] context.index+=1 - outputs = model.generate(**prompts, logits_processor=logits_processor, - max_new_tokens=max_new_tokens) + if logits_processor is not None: + outputs = model.generate(**prompts, logits_processor=logits_processor, + max_new_tokens=max_new_tokens) + else: + outputs = model.generate(**prompts, prefix_allowed_tokens_fn=prefix_fn, + max_new_tokens=max_new_tokens) context.tokens += outputs.shape[-1]-input_len l = logits_processor if l and isinstance(l[0], FormattersLogitsProcessor): @@ -72,7 +86,11 @@ def bench(result:BenchResult, context:Context,func, bench_name:str, f): result.t1 = context.tokens context.index = 0 context.tokens = 0 - logits_processor.clear() + if logits_processor is not None: + logits_processor.clear() + else: + global prefix_fn + prefix_fn = None result.s2 = (timeit(func, number=len(inputs))) result.t2 = context.tokens log(bench_name, result, f) @@ -91,16 +109,26 @@ def bench(result:BenchResult, context:Context,func, bench_name:str, f): model, tokenizer = get_llama3_8b_tokenizer_and_model() model.eval() max_new_tokens = 50 - inputs = json.load(open("address.json"))["sentences"] + inputs = load_address() + prefix_fn = None logits_processor = get_address_schema() - bench(data,context,execute, "llama3_8b_address_json", f) - inputs = json.load(open("linkedlist.json"))["sentences"] + bench(data,context,execute, "formatron_llama3_8b_address_json", f) + logits_processor = None + prefix_fn = lfe_address_prefix() + bench(data, context, execute, "lm_format_enforcer_llama3_8b_address_json", f) + inputs = load_linkedlist() logits_processor = get_linkedlist_schema() max_new_tokens = 200 - bench(data,context,execute, "llama3_8b_linkedlist_json", f) - inputs = json.load(open("orders.json"))["orders"] + bench(data,context,execute, "formatron_llama3_8b_linkedlist_json", f) + logits_processor = None + prefix_fn = lfe_linkedlist_prefix() + bench(data, context, execute, "lm_format_enforcer_llama3_8b_linkedlist_json", f) + inputs = load_orders() logits_processor = get_order_schema() - bench(data, context, execute, "llama3_8b_order_json", f) + bench(data, context, execute, "formatron_llama3_8b_order_json", f) + logits_processor = None + prefix_fn = lfe_order_prefix() + bench(data, context, execute, "lm_format_enforcer_llama3_8b_order_json", f) system_prompt = """[INST] You are a helpful AI assistant for information extraction. @@ -113,14 +141,23 @@ def bench(result:BenchResult, context:Context,func, bench_name:str, f): model, tokenizer = get_llama2_7b_tokenizer_and_model() model.eval() max_new_tokens = 50 - inputs = json.load(open("address.json"))["sentences"] + inputs = load_address() logits_processor = get_address_schema() - bench(data, context, execute, "llama2_7b_address_json", f) + bench(data, context, execute, "formatron_llama2_7b_address_json", f) + logits_processor = None + prefix_fn = lfe_address_prefix() + bench(data, context, execute, "lm_format_enforcer_llama2_7b_address_json", f) max_new_tokens = 30 - inputs = json.load(open("linkedlist.json"))["sentences"] + inputs = load_linkedlist() logits_processor = get_linkedlist_schema() max_new_tokens = 200 - bench(data, context, execute, "llama2_7b_linkedlist_json", f) - inputs = json.load(open("orders.json"))["orders"] + bench(data, context, execute, "formatron_llama2_7b_linkedlist_json", f) + logits_processor = None + prefix_fn = lfe_linkedlist_prefix() + bench(data, context, execute, "lm_format_enforcer_llama2_7b_linkedlist_json", f) + inputs = load_orders() logits_processor = get_order_schema() - bench(data, context, execute, "llama2_7b_order_json", f) \ No newline at end of file + bench(data, context, execute, "formatron_llama2_7b_order_json", f) + logits_processor = None + prefix_fn = lfe_order_prefix() + bench(data, context, execute, "lm_format_enforcer_llama2_7b_order_json", f) \ No newline at end of file diff --git a/benchmarks/transformers_json.txt b/benchmarks/transformers_json.txt index 50835905..06090d7a 100644 --- a/benchmarks/transformers_json.txt +++ b/benchmarks/transformers_json.txt @@ -1,12 +1,24 @@ -llama3_8b_address_json generated 786 tokens with 37.39219380500281 tps (with warm up) -llama3_8b_address_json unconstrained generated 1000 tokens with 38.71077301668357 tps -llama3_8b_linkedlist_json generated 887 tokens with 37.24889710198721 tps (with warm up) -llama3_8b_linkedlist_json unconstrained generated 1236 tokens with 38.65240435557168 tps -llama3_8b_order_json generated 3505 tokens with 36.733292562686295 tps (with warm up) -llama3_8b_order_json unconstrained generated 3969 tokens with 38.1100638135264 tps -llama2_7b_address_json generated 931 tokens with 41.29915523229768 tps (with warm up) -llama2_7b_address_json unconstrained generated 1000 tokens with 42.14071078188422 tps -llama2_7b_linkedlist_json generated 220 tokens with 40.7547980187885 tps (with warm up) -llama2_7b_linkedlist_json unconstrained generated 4000 tokens with 41.914465154549895 tps -llama2_7b_order_json generated 3973 tokens with 39.69588647270487 tps (with warm up) -llama2_7b_order_json unconstrained generated 4000 tokens with 40.41234834607258 tps +formatron_llama3_8b_address_json generated 786 tokens with 37.90142173165388 tps (with warm up) +formatron_llama3_8b_address_json unconstrained generated 1000 tokens with 38.85116671317693 tps +lm_format_enforcer_llama3_8b_address_json generated 899 tokens with 28.438352529535496 tps (with warm up) +lm_format_enforcer_llama3_8b_address_json unconstrained generated 1000 tokens with 38.73050778002801 tps +formatron_llama3_8b_linkedlist_json generated 887 tokens with 37.74593320321761 tps (with warm up) +formatron_llama3_8b_linkedlist_json unconstrained generated 1236 tokens with 38.7681057952068 tps +lm_format_enforcer_llama3_8b_linkedlist_json generated 1015 tokens with 34.091645039329364 tps (with warm up) +lm_format_enforcer_llama3_8b_linkedlist_json unconstrained generated 1236 tokens with 38.75929416634396 tps +formatron_llama3_8b_order_json generated 3505 tokens with 37.233423883274135 tps (with warm up) +formatron_llama3_8b_order_json unconstrained generated 3969 tokens with 38.21995325408202 tps +lm_format_enforcer_llama3_8b_order_json generated 3545 tokens with 30.99588997163373 tps (with warm up) +lm_format_enforcer_llama3_8b_order_json unconstrained generated 3969 tokens with 38.218372294863066 tps +formatron_llama2_7b_address_json generated 974 tokens with 41.56895625319404 tps (with warm up) +formatron_llama2_7b_address_json unconstrained generated 1000 tokens with 42.287667110185744 tps +lm_format_enforcer_llama2_7b_address_json generated 1000 tokens with 39.86783805027324 tps (with warm up) +lm_format_enforcer_llama2_7b_address_json unconstrained generated 1000 tokens with 42.28065331989889 tps +formatron_llama2_7b_linkedlist_json generated 220 tokens with 41.157716280397565 tps (with warm up) +formatron_llama2_7b_linkedlist_json unconstrained generated 4000 tokens with 42.088925145411324 tps +lm_format_enforcer_llama2_7b_linkedlist_json generated 3428 tokens with 41.081280656213565 tps (with warm up) +lm_format_enforcer_llama2_7b_linkedlist_json unconstrained generated 4000 tokens with 42.08850708467392 tps +formatron_llama2_7b_order_json generated 3948 tokens with 39.96876102345699 tps (with warm up) +formatron_llama2_7b_order_json unconstrained generated 4000 tokens with 40.680081826758254 tps +lm_format_enforcer_llama2_7b_order_json generated 4000 tokens with 39.143187642063694 tps (with warm up) +lm_format_enforcer_llama2_7b_order_json unconstrained generated 4000 tokens with 40.67817922991219 tps diff --git a/benchmarks/utils.py b/benchmarks/utils.py index 2dcf62ca..d06adc6c 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -1,7 +1,11 @@ +import gc +import json from dataclasses import dataclass from typing import Optional -from formatron.schemas.pydantic import ClassSchema +import torch +from formatron.schemas.pydantic import ClassSchema +from lmformatenforcer import JsonSchemaParser class Address(ClassSchema): street: str @@ -10,10 +14,14 @@ class Address(ClassSchema): postal_code: str country: str +address_lfe = JsonSchemaParser(Address.model_json_schema()) + class LinkedList(ClassSchema): value: int next: Optional["LinkedList"] +linked_list_lfe = JsonSchemaParser(LinkedList.model_json_schema()) + class OrderItem(ClassSchema): product_id: int variant_id: int @@ -35,6 +43,8 @@ class Order(ClassSchema): total_amount: float status: str +order_lfe = JsonSchemaParser(Order.model_json_schema()) + @dataclass class BenchResult: t1:int @@ -53,4 +63,18 @@ def log(func_name:str, data:BenchResult,f): f" {data.t2 / data.s2} tps\n") print(a) print(b) - f.writelines([a,b]) \ No newline at end of file + f.writelines([a,b]) + + +def load_address()->list[str]: + return json.load(open("address.json"))["sentences"] + +def load_linkedlist()->list[str]: + return json.load(open("linkedlist.json"))["sentences"] + +def load_orders()->list[str]: + return json.load(open("orders.json"))["orders"] + +def force_gc(): + torch.cuda.empty_cache() + gc.collect() \ No newline at end of file