From 0b4abc919c803ca3dffcf50fca74c8e66b54749e Mon Sep 17 00:00:00 2001
From: Huanghe
Date: Fri, 16 Aug 2024 18:14:44 -0500
Subject: [PATCH] Bug fixes & update benchmarks

- correctly revert the huggingface bytelevel vocabulary preprocessor
- vllm integration bug fix on batches of different sequence lengths
---
 benchmarks/result.md                          | 12 ++++-----
 benchmarks/vllm_json_bench.txt                | 24 ++++++++---------
 src/formatron/integrations/_utils.py          | 27 ++++++++++++++-----
 src/formatron/integrations/vllm.py            |  5 ++--
 tests/snapshots/snap_test_vllm_integration.py |  2 +-
 tests/test_vllm_integration.py                |  4 +--
 6 files changed, 45 insertions(+), 29 deletions(-)

diff --git a/benchmarks/result.md b/benchmarks/result.md
index 7983fc84..996356e8 100644
--- a/benchmarks/result.md
+++ b/benchmarks/result.md
@@ -22,12 +22,12 @@ Default vllm setting are used.
 
 | model           | schema          | constrained(with warm-up) / tps | unconstrained / tps | overhead per token / ms |
 |-----------------|-----------------|---------------------------------|---------------------|-------------------------|
-| Llama3-8B(bf16) | address_json    | 40.85                           | 41.93               | 0.63                    |
-| Llama3-8B(bf16) | linkedlist_json | 40.54                           | 41.83               | 0.76                    |
-| Llama3-8B(bf16) | order_json      | 40.03                           | 41.45               | 0.86                    |
-| Llama2-7B(fp16) | address_json    | 46.54                           | 47.51               | 0.44                    |
-| Llama2-7B(fp16) | linkedlist_json | 46.50                           | 47.51               | 0.46                    |
-| Llama2-7B(fp16) | order_json      | 45.69                           | 46.65               | 0.45                    |
+| Llama3-8B(bf16) | address_json    | 41.10                           | 41.97               | 0.50                    |
+| Llama3-8B(bf16) | linkedlist_json | 40.80                           | 41.91               | 0.65                    |
+| Llama3-8B(bf16) | order_json      | 40.24                           | 41.52               | 0.77                    |
+| Llama2-7B(fp16) | address_json    | 46.92                           | 47.69               | 0.34                    |
+| Llama2-7B(fp16) | linkedlist_json | 46.80                           | 47.71               | 0.41                    |
+| Llama2-7B(fp16) | order_json      | 45.96                           | 46.84               | 0.41                    |
 
 ## Exllamav2
 Default exllamav2 setting are used.
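Note on the updated vLLM table: the overhead-per-token column is consistent with the difference of the per-token latencies implied by the two throughput columns, i.e. 1000/constrained_tps - 1000/unconstrained_tps. For the Llama3-8B address_json row, 1000/41.10 - 1000/41.97 ≈ 24.33 - 23.83 ≈ 0.50 ms, matching the table; this worked check is only a reading of the numbers above, not part of the patch.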
diff --git a/benchmarks/vllm_json_bench.txt b/benchmarks/vllm_json_bench.txt
index 6c9d876d..8b34fc11 100644
--- a/benchmarks/vllm_json_bench.txt
+++ b/benchmarks/vllm_json_bench.txt
@@ -1,12 +1,12 @@
-llama3_8b_vllm_address generated 1085 tokens with 40.8469097124382 tps (with warm up)
-llama3_8b_vllm_address unconstrained generated 1526 tokens with 41.92947897884734 tps
-llama3_8b_linkedlist generated 1252 tokens with 40.53841940713149 tps (with warm up)
-llama3_8b_linkedlist unconstrained generated 1389 tokens with 41.83142077141228 tps
-llama3_8b_orders generated 4266 tokens with 40.03391057585299 tps (with warm up)
-llama3_8b_orders unconstrained generated 4595 tokens with 41.45157671971401 tps
-llama2_7b_vllm_address generated 1755 tokens with 46.537152877353094 tps (with warm up)
-llama2_7b_vllm_address unconstrained generated 1702 tokens with 47.51143248721739 tps
-llama2_7b_linkedlist generated 1918 tokens with 46.50518276029844 tps (with warm up)
-llama2_7b_linkedlist unconstrained generated 1903 tokens with 47.51484558427169 tps
-llama2_7b_orders generated 5073 tokens with 45.68708609099978 tps (with warm up)
-llama2_7b_orders unconstrained generated 5110 tokens with 46.65271793400345 tps
+llama3_8b_vllm_address generated 1085 tokens with 41.09999767701784 tps (with warm up)
+llama3_8b_vllm_address unconstrained generated 1526 tokens with 41.968252244152325 tps
+llama3_8b_linkedlist generated 1252 tokens with 40.797074516684404 tps (with warm up)
+llama3_8b_linkedlist unconstrained generated 1389 tokens with 41.91172085238663 tps
+llama3_8b_orders generated 4266 tokens with 40.24421333602719 tps (with warm up)
+llama3_8b_orders unconstrained generated 4595 tokens with 41.520316679742486 tps
+llama2_7b_vllm_address generated 1175 tokens with 46.92316431466862 tps (with warm up)
+llama2_7b_vllm_address unconstrained generated 1751 tokens with 47.68570809701806 tps
+llama2_7b_linkedlist generated 1288 tokens with 46.79789280862108 tps (with warm up)
+llama2_7b_linkedlist unconstrained generated 1945 tokens with 47.71323389337299 tps
+llama2_7b_orders generated 4207 tokens with 45.96312899147878 tps (with warm up)
+llama2_7b_orders unconstrained generated 5112 tokens with 46.83597327628955 tps
diff --git a/src/formatron/integrations/_utils.py b/src/formatron/integrations/_utils.py
index ccb93753..44a383c5 100644
--- a/src/formatron/integrations/_utils.py
+++ b/src/formatron/integrations/_utils.py
@@ -1,5 +1,6 @@
 import re
 import typing
+from functools import lru_cache
 
 
 def _multiple_replace(replacements, text):
@@ -17,15 +18,12 @@ def _autodetect_processors(vocab: typing.Dict[str, int]):
     llama_present = any(i.find('<0xF0>') != -1 for i in vocab.keys())
     underscore_present = (len([1 for i in vocab.keys() if i.find('\u2581') != -1]) / len(vocab)) > 0.2
     g_present = (len([1 for i in vocab.keys() if i.find('\u0120') != -1]) / len(vocab)) > 0.2
-    c_present = any(i.find('\u010A') != -1 for i in vocab.keys())
     if llama_present:
         result.add("<0xHH>")
     if underscore_present:
         result.add("sentencepiece")
     elif g_present:
         result.add("dot_G")
-    if c_present:
-        result.add("dot_C")
     return result
 
 
@@ -36,9 +34,7 @@ def get_original_characters(vocab: typing.Dict[str, int]) -> typing.Dict[bytes,
         if i == "sentencepiece":
             old_char_to_new_char["\u2581".encode("UTF-8")] = b" "
         elif i == "dot_G":
-            old_char_to_new_char["\u0120".encode("UTF-8")] = b" "
-        elif i == "dot_C":
-            old_char_to_new_char["\u010A".encode("UTF-8")] = b"\n"
+            old_char_to_new_char.update(huggingface_bytelevel_decoder())
         elif i == "<0xHH>":
             for j in range(256):
                 old_char_to_new_char[("<0x" + f"{j:02x}".upper() + ">").encode("UTF-8")] = bytes([j])
@@ -51,3 +47,22 @@ def get_original_characters(vocab: typing.Dict[str, int]) -> typing.Dict[bytes,
         new_k = _multiple_replace(old_char_to_new_char, k)
         new_vocab[new_k] = token_id
     return new_vocab
+
+
+@lru_cache()
+def huggingface_bytelevel_decoder():
+    """
+    I hate legacy code.
+    """
+    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8+n)
+            n += 1
+    cs = [chr(n).encode("UTF-8") for n in cs]
+    for i in range(len(bs)):
+        bs[i] = bytes([bs[i]])
+    return dict(zip(cs, bs))
\ No newline at end of file
diff --git a/src/formatron/integrations/vllm.py b/src/formatron/integrations/vllm.py
index 4a334626..351833e8 100644
--- a/src/formatron/integrations/vllm.py
+++ b/src/formatron/integrations/vllm.py
@@ -64,8 +64,9 @@ def __call__(self, prompt, generated_tokens, logits):
             self._to_next_batch_step()
         result = next(self._iter)
         self._last_input_id_length += 1
-
         formatter, _ = result
+        while formatter.is_completed():
+            formatter, _ = next(self._iter)
         if len(generated_tokens) != 0:  # accept new token
             input_id = generated_tokens[-1]
             if input_id != self._eos_token_id:
@@ -73,7 +74,7 @@ def __call__(self, prompt, generated_tokens, logits):
 
         if formatter.is_completed():
             logits[:] = float("-inf")
-            logits[self._eos_token_id] = 0.0
+            logits[self._eos_token_id] = 1000
             return logits
         formatter.compute_allowed_tokens()
         logits = formatter.mask_logits(logits)
diff --git a/tests/snapshots/snap_test_vllm_integration.py b/tests/snapshots/snap_test_vllm_integration.py
index bd4f7ca8..6a00919a 100644
--- a/tests/snapshots/snap_test_vllm_integration.py
+++ b/tests/snapshots/snap_test_vllm_integration.py
@@ -9,4 +9,4 @@
 
 snapshots['test_vllm_integration 1'] = "Prompt: 'Hello, my name is', Generated text: 'definitely vllm!\\n'"
 
-snapshots['test_vllm_integration 2'] = "Prompt: 'The future of AI is', Generated text: 'vllm for sure!\\n'"
+snapshots['test_vllm_integration 2'] = "Prompt: 'The future of AI is', Generated text: '强大的【VLLM】!\\n'"
diff --git a/tests/test_vllm_integration.py b/tests/test_vllm_integration.py
index 3a07d99f..baa560ef 100644
--- a/tests/test_vllm_integration.py
+++ b/tests/test_vllm_integration.py
@@ -13,9 +13,9 @@ def test_vllm_integration(snapshot):
     f = FormatterBuilder()
     f.append_line("definitely vllm!")
     f2 = FormatterBuilder()
-    f2.append_line("vllm for sure!")
+    f2.append_line("强大的【VLLM】!")
     logits_processor = create_formatters_logits_processor(llm, [f, f2])
-    sampling_params = SamplingParams(temperature=0.8, top_p=0.95, logits_processors=[logits_processor])
+    sampling_params = SamplingParams(max_tokens=50,temperature=0.8, top_p=0.95, logits_processors=[logits_processor])
     # Generate texts from the prompts. The output is a list of RequestOutput objects
     # that contain the prompt, generated text, and other information.
     outputs = llm.generate(prompts, sampling_params)
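Note on the _utils.py change: huggingface_bytelevel_decoder() rebuilds the byte-to-unicode table used by Hugging Face byte-level BPE tokenizers and inverts it, so every placeholder character in the vocabulary is mapped back to its original byte, rather than only the "Ġ" (space) and "Ċ" (newline) cases that the removed dot_G/dot_C assignments handled. The sketch below reproduces that inverse table outside the patch to show what the mapping contains; the variable names and spot checks are illustrative only and are not part of the repository code.

# Standalone sketch of the table returned by huggingface_bytelevel_decoder().
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(256):
    if b not in bs:  # bytes without a printable representation are remapped to codepoint 256 + n
        bs.append(b)
        cs.append(256 + n)
        n += 1
# Keys are the UTF-8 encodings of the placeholder characters, values are the original raw bytes.
decoder = {chr(c).encode("UTF-8"): bytes([b]) for c, b in zip(cs, bs)}

assert decoder["\u0120".encode("UTF-8")] == b" "   # "Ġ" stands for a leading space
assert decoder["\u010A".encode("UTF-8")] == b"\n"  # "Ċ" stands for a newline
assert len(decoder) == 256                         # every byte has exactly one placeholder

get_original_characters() feeds this mapping into _multiple_replace(), so a vocabulary entry such as "Ġworld" is restored to the raw bytes b" world" before the grammar engine sees it.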