Bug fixes & updated benchmarks
- correctly revert the Hugging Face byte-level vocabulary preprocessor

- fix a vLLM integration bug on batches with different sequence lengths
Dan-wanna-M committed Aug 16, 2024
1 parent fb1fea3 commit 0b4abc9
Showing 6 changed files with 45 additions and 29 deletions.
12 changes: 6 additions & 6 deletions benchmarks/result.md
@@ -22,12 +22,12 @@ Default vllm setting are used.

| model | schema | constrained (with warm-up) / tps | unconstrained / tps | overhead per token / ms |
|-----------------|-----------------|---------------------------------|---------------------|-------------------------|
| Llama3-8B(bf16) | address_json | 40.85 | 41.93 | 0.63 |
| Llama3-8B(bf16) | linkedlist_json | 40.54 | 41.83 | 0.76 |
| Llama3-8B(bf16) | order_json | 40.03 | 41.45 | 0.86 |
| Llama2-7B(fp16) | address_json | 46.54 | 47.51 | 0.44 |
| Llama2-7B(fp16) | linkedlist_json | 46.50 | 47.51 | 0.46 |
| Llama2-7B(fp16) | order_json | 45.69 | 46.65 | 0.45 |
| Llama3-8B(bf16) | address_json | 41.10 | 41.97 | 0.50 |
| Llama3-8B(bf16) | linkedlist_json | 40.80 | 41.91 | 0.65 |
| Llama3-8B(bf16) | order_json | 40.24 | 41.52 | 0.77 |
| Llama2-7B(fp16) | address_json | 46.92 | 47.69 | 0.34 |
| Llama2-7B(fp16) | linkedlist_json | 46.80 | 47.71 | 0.41 |
| Llama2-7B(fp16) | order_json | 45.96 | 46.84 | 0.41 |
## Exllamav2
Default exllamav2 settings are used.

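The overhead column appears to be the difference between the per-token latencies implied by the two throughput columns. A small sketch of that arithmetic; this is my reading of the table, not something stated by the benchmark script:

```python
def overhead_ms_per_token(constrained_tps: float, unconstrained_tps: float) -> float:
    """Per-token overhead in milliseconds implied by two tokens-per-second figures."""
    return 1000.0 / constrained_tps - 1000.0 / unconstrained_tps

# Llama3-8B(bf16), address_json row from the updated table above: ~0.50 ms.
print(f"{overhead_ms_per_token(41.10, 41.97):.2f}")
```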
24 changes: 12 additions & 12 deletions benchmarks/vllm_json_bench.txt
@@ -1,12 +1,12 @@
llama3_8b_vllm_address generated 1085 tokens with 40.8469097124382 tps (with warm up)
llama3_8b_vllm_address unconstrained generated 1526 tokens with 41.92947897884734 tps
llama3_8b_linkedlist generated 1252 tokens with 40.53841940713149 tps (with warm up)
llama3_8b_linkedlist unconstrained generated 1389 tokens with 41.83142077141228 tps
llama3_8b_orders generated 4266 tokens with 40.03391057585299 tps (with warm up)
llama3_8b_orders unconstrained generated 4595 tokens with 41.45157671971401 tps
llama2_7b_vllm_address generated 1755 tokens with 46.537152877353094 tps (with warm up)
llama2_7b_vllm_address unconstrained generated 1702 tokens with 47.51143248721739 tps
llama2_7b_linkedlist generated 1918 tokens with 46.50518276029844 tps (with warm up)
llama2_7b_linkedlist unconstrained generated 1903 tokens with 47.51484558427169 tps
llama2_7b_orders generated 5073 tokens with 45.68708609099978 tps (with warm up)
llama2_7b_orders unconstrained generated 5110 tokens with 46.65271793400345 tps
llama3_8b_vllm_address generated 1085 tokens with 41.09999767701784 tps (with warm up)
llama3_8b_vllm_address unconstrained generated 1526 tokens with 41.968252244152325 tps
llama3_8b_linkedlist generated 1252 tokens with 40.797074516684404 tps (with warm up)
llama3_8b_linkedlist unconstrained generated 1389 tokens with 41.91172085238663 tps
llama3_8b_orders generated 4266 tokens with 40.24421333602719 tps (with warm up)
llama3_8b_orders unconstrained generated 4595 tokens with 41.520316679742486 tps
llama2_7b_vllm_address generated 1175 tokens with 46.92316431466862 tps (with warm up)
llama2_7b_vllm_address unconstrained generated 1751 tokens with 47.68570809701806 tps
llama2_7b_linkedlist generated 1288 tokens with 46.79789280862108 tps (with warm up)
llama2_7b_linkedlist unconstrained generated 1945 tokens with 47.71323389337299 tps
llama2_7b_orders generated 4207 tokens with 45.96312899147878 tps (with warm up)
llama2_7b_orders unconstrained generated 5112 tokens with 46.83597327628955 tps
27 changes: 21 additions & 6 deletions src/formatron/integrations/_utils.py
@@ -1,5 +1,6 @@
import re
import typing
from functools import lru_cache


def _multiple_replace(replacements, text):
@@ -17,15 +18,12 @@ def _autodetect_processors(vocab: typing.Dict[str, int]):
    llama_present = any(i.find('<0xF0>') != -1 for i in vocab.keys())
    underscore_present = (len([1 for i in vocab.keys() if i.find('\u2581') != -1]) / len(vocab)) > 0.2
    g_present = (len([1 for i in vocab.keys() if i.find('\u0120') != -1]) / len(vocab)) > 0.2
    c_present = any(i.find('\u010A') != -1 for i in vocab.keys())
    if llama_present:
        result.add("<0xHH>")
    if underscore_present:
        result.add("sentencepiece")
    elif g_present:
        result.add("dot_G")
    if c_present:
        result.add("dot_C")
    return result


@@ -36,9 +34,7 @@ def get_original_characters(vocab: typing.Dict[str, int]) -> typing.Dict[bytes,
        if i == "sentencepiece":
            old_char_to_new_char["\u2581".encode("UTF-8")] = b" "
        elif i == "dot_G":
            old_char_to_new_char["\u0120".encode("UTF-8")] = b" "
            old_char_to_new_char.update(huggingface_bytelevel_decoder())
        elif i == "dot_C":
            old_char_to_new_char["\u010A".encode("UTF-8")] = b"\n"
        elif i == "<0xHH>":
            for j in range(256):
                old_char_to_new_char[("<0x" + f"{j:02x}".upper() + ">").encode("UTF-8")] = bytes([j])
@@ -51,3 +47,22 @@ def get_original_characters(vocab: typing.Dict[str, int]) -> typing.Dict[bytes,
        new_k = _multiple_replace(old_char_to_new_char, k)
        new_vocab[new_k] = token_id
    return new_vocab


@lru_cache()
def huggingface_bytelevel_decoder():
    """
    I hate legacy code.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n).encode("UTF-8") for n in cs]
    for i in range(len(bs)):
        bs[i] = bytes([bs[i]])
    return dict(zip(cs, bs))
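For context, the mapping above is the inverse of the GPT-2-style byte-level encoding that Hugging Face byte-level BPE tokenizers apply to raw bytes. Because it already sends 'Ġ' (U+0120) back to a space and 'Ċ' (U+010A) back to a newline, it covers what the dot_C branch used to handle, which is presumably why that special case no longer needs its own code path. A minimal sketch of how the mapping behaves, assuming the module is importable as formatron.integrations._utils:

```python
from formatron.integrations._utils import huggingface_bytelevel_decoder

mapping = huggingface_bytelevel_decoder()
assert mapping["\u0120".encode("UTF-8")] == b" "   # 'Ġ' -> space
assert mapping["\u010A".encode("UTF-8")] == b"\n"  # 'Ċ' -> newline
assert len(mapping) == 256                         # every raw byte has exactly one encoded form
```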
5 changes: 3 additions & 2 deletions src/formatron/integrations/vllm.py
@@ -64,16 +64,17 @@ def __call__(self, prompt, generated_tokens, logits):
        self._to_next_batch_step()
        result = next(self._iter)
        self._last_input_id_length += 1

        formatter, _ = result
        while formatter.is_completed():
            formatter, _ = next(self._iter)
        if len(generated_tokens) != 0:  # accept new token
            input_id = generated_tokens[-1]
            if input_id != self._eos_token_id:
                formatter.accept_token(input_id)

        if formatter.is_completed():
            logits[:] = float("-inf")
            logits[self._eos_token_id] = 0.0
            logits[self._eos_token_id] = 1000
            return logits
        formatter.compute_allowed_tokens()
        logits = formatter.mask_logits(logits)
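For intuition, the completed-formatter branch above forces decoding to stop by making the EOS token the only candidate with non-negligible probability. A minimal, framework-free sketch of that effect, using plain Python and a hypothetical five-token vocabulary:

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical vocabulary of 5 tokens with EOS at index 2.
vocab_size, eos_token_id = 5, 2
logits = [float("-inf")] * vocab_size  # mask every token...
logits[eos_token_id] = 1000            # ...except EOS, mirroring the branch above
print(softmax(logits))                 # -> [0.0, 0.0, 1.0, 0.0, 0.0]
```

Either 0.0 or 1000 gives EOS all of the probability mass once the rest of the row is -inf; the larger value presumably just keeps EOS dominant if later processing adds finite adjustments to the row.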
2 changes: 1 addition & 1 deletion tests/snapshots/snap_test_vllm_integration.py
@@ -9,4 +9,4 @@

snapshots['test_vllm_integration 1'] = "Prompt: 'Hello, my name is', Generated text: 'definitely vllm!\\n'"

snapshots['test_vllm_integration 2'] = "Prompt: 'The future of AI is', Generated text: 'vllm for sure!\\n'"
snapshots['test_vllm_integration 2'] = "Prompt: 'The future of AI is', Generated text: '强大的【VLLM】!\\n'"
4 changes: 2 additions & 2 deletions tests/test_vllm_integration.py
@@ -13,9 +13,9 @@ def test_vllm_integration(snapshot):
    f = FormatterBuilder()
    f.append_line("definitely vllm!")
    f2 = FormatterBuilder()
    f2.append_line("vllm for sure!")
    f2.append_line("强大的【VLLM】!")
    logits_processor = create_formatters_logits_processor(llm, [f, f2])
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95, logits_processors=[logits_processor])
    sampling_params = SamplingParams(max_tokens=50, temperature=0.8, top_p=0.95, logits_processors=[logits_processor])
    # Generate texts from the prompts. The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params)
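Read together with the snapshot file above, the updated test presumably drives two prompts of different lengths through the two formatters in a single batch, which is what exercises the sequence-length fix. A hedged sketch of the full setup; the import paths and the model checkpoint are assumptions, while the prompt strings and builder calls come from the snapshot and the hunk above:

```python
from vllm import LLM, SamplingParams

from formatron.formatter import FormatterBuilder  # assumed import path
from formatron.integrations.vllm import create_formatters_logits_processor

# Two prompts of different lengths, taken from the snapshot strings above.
prompts = ["Hello, my name is", "The future of AI is"]

llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")  # hypothetical checkpoint; the test's model is not shown here
f = FormatterBuilder()
f.append_line("definitely vllm!")
f2 = FormatterBuilder()
f2.append_line("强大的【VLLM】!")  # multi-byte UTF-8 output exercises the byte-level vocabulary fix
logits_processor = create_formatters_logits_processor(llm, [f, f2])
sampling_params = SamplingParams(max_tokens=50, temperature=0.8, top_p=0.95,
                                 logits_processors=[logits_processor])
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}")
```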
