Skip to content

Commit

Permalink
Fix everything
Browse files Browse the repository at this point in the history
  • Loading branch information
Dan-wanna-M committed Aug 7, 2024
1 parent 624cb01 commit a3bb7ce
Show file tree
Hide file tree
Showing 8 changed files with 49 additions and 46 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ authors = [
]
description = "Formatron empowers everyone to control the output format of language models with minimal overhead."
readme = "README.md"
dependencies = ["pydantic>=2","kbnf>=0.2.3"]
dependencies = ["pydantic>=2","kbnf>=0.2.4"]
license = {file = "LICENSE"}
keywords = ["deep learning", "language model", "guided generation", "structured generation","constrained decoding"]
requires-python = ">=3.10"
Expand Down
2 changes: 1 addition & 1 deletion src/formatron/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def str(self, *, stop: typing.Union[str, list[str]] = None,
capture_regex = ".*"
nonterminal_regex = "#'.*'"
else:
capture_regex = f".*?(?:{'|'.join(map(re.escape, stop + not_contain))})"
capture_regex = f".*?(?:{'|'.join(map(re.escape, stop))})"
excepted = f"{nonterminal}_excepted"
end = f"({'|'.join(map(repr, stop))})" if stop else ""
nonterminal_regex = f"except!({excepted}){end}"
Expand Down
14 changes: 9 additions & 5 deletions src/formatron/integrations/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,18 @@ def _multiple_replace(replacements, text):

def _autodetect_processors(vocab:typing.Dict[str, int]):
result = set()
space_present = any(i.find(' ')!=-1 for i in vocab.keys())
llama_present = any(i.find('<0xF0>')!=-1 for i in vocab.keys())
underscore_present = any(i.find('\u2581')!=-1 for i in vocab.keys())
g_present = any(i.find('\u0120')!=-1 for i in vocab.keys())
underscore_present = (len([1 for i in vocab.keys() if i.find('\u2581')!=-1]) / len(vocab)) > 0.2
g_present = (len([1 for i in vocab.keys() if i.find('\u0120')!=-1]) / len(vocab)) > 0.2
c_present = any(i.find('\u010A') != -1 for i in vocab.keys())
if llama_present:
result.add("<0xHH>")
if not space_present and underscore_present:
if underscore_present:
result.add("sentencepiece")
elif not space_present and g_present:
elif g_present:
result.add("dot_G")
if c_present:
result.add("dot_C")
return result

def get_original_characters(vocab:typing.Dict[str, int]) -> typing.Dict[bytes, int]:
Expand All @@ -32,6 +34,8 @@ def get_original_characters(vocab:typing.Dict[str, int]) -> typing.Dict[bytes, i
old_char_to_new_char["\u2581".encode("UTF-8")] = b" "
elif i == "dot_G":
old_char_to_new_char["\u0120".encode("UTF-8")] = b" "
elif i == "dot_C":
old_char_to_new_char["\u010A".encode("UTF-8")] = b"\n"
elif i == "<0xHH>":
for j in range(256):
old_char_to_new_char[("<0x"+f"{j:02x}".upper()+">").encode("UTF-8")] = bytes([j])
Expand Down
52 changes: 26 additions & 26 deletions src/formatron/integrations/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,32 +12,6 @@
from integrations._utils import get_original_characters


def create_engine_vocabulary(llm: LLM) -> kbnf.Vocabulary:
    """
    Create a vocabulary for the KBNF engine from a vLLM ``LLM``'s tokenizer.

    :param llm: the vLLM engine whose tokenizer supplies the vocabulary.
    :return: a ``kbnf.Vocabulary`` mapping token id -> ``kbnf.Token`` and
        token id -> original token string.
    """
    tokenizer = llm.get_tokenizer()
    vocab = tokenizer.get_vocab()
    # get_original_characters expects only the str->id vocab dict
    # (passing the tokenizer as well is an arity error), and it returns
    # bytes keys, so the tokens go to kbnf.Token as-is — calling
    # .encode("utf-8") on them would raise AttributeError.
    new_vocab = get_original_characters(vocab)
    return kbnf.Vocabulary({v: kbnf.Token(k) for k, v in new_vocab.items()},
                           {v: k for k, v in vocab.items()})


def create_formatters_logits_processor(llm: LLM,
                                       formatter_builders: typing.Sequence[FormatterBuilder] | FormatterBuilder,
                                       configs: typing.Sequence[EngineGenerationConfig] = None) \
        -> "FormattersLogitsProcessor":
    """
    Build a ``FormattersLogitsProcessor`` for the given vLLM engine.

    Accepts either a single ``FormatterBuilder`` or a sequence of them;
    a lone builder is wrapped in a one-element list before building.
    """
    tokenizer = llm.get_tokenizer()
    engine_vocab = create_engine_vocabulary(llm)
    if not isinstance(formatter_builders, collections.abc.Sequence):
        formatter_builders = [formatter_builders]
    decode = lambda tokens: tokenizer.decode(tokens)
    built = []
    for builder in formatter_builders:
        built.append(builder.build(engine_vocab, decode))
    return FormattersLogitsProcessor(built, tokenizer.eos_token_id, configs)


class FormattersLogitsProcessor:
"""
Logit processor that uses formatters to mask batch logits.
Expand Down Expand Up @@ -92,3 +66,29 @@ def __call__(self, prompt, generated_tokens, logits):
formatter.compute_allowed_tokens()
logits = formatter.mask_logits(logits)
return logits


def create_engine_vocabulary(llm: LLM) -> kbnf.Vocabulary:
    """
    Build the KBNF engine vocabulary from a vLLM ``LLM``'s tokenizer.
    """
    tokenizer = llm.get_tokenizer()
    vocab = tokenizer.get_vocab()
    new_vocab = get_original_characters(vocab)
    # id -> kbnf.Token built from the normalized vocabulary,
    # id -> original token string from the raw tokenizer vocabulary.
    id_to_token = {token_id: kbnf.Token(token) for token, token_id in new_vocab.items()}
    id_to_string = {token_id: token for token, token_id in vocab.items()}
    return kbnf.Vocabulary(id_to_token, id_to_string)


def create_formatters_logits_processor(llm: LLM,
                                       formatter_builders: typing.Sequence[FormatterBuilder] | FormatterBuilder,
                                       configs: typing.Sequence[EngineGenerationConfig] = None) \
        -> FormattersLogitsProcessor:
    """
    Create a logits processor that constrains vLLM generation with formatters.

    A single ``FormatterBuilder`` is treated as a one-element sequence.
    """
    tokenizer = llm.get_tokenizer()
    engine_vocabulary = create_engine_vocabulary(llm)
    if not isinstance(formatter_builders, collections.abc.Sequence):
        formatter_builders = [formatter_builders]
    formatters = [
        builder.build(engine_vocabulary, lambda tokens: tokenizer.decode(tokens))
        for builder in formatter_builders
    ]
    return FormattersLogitsProcessor(formatters, tokenizer.eos_token_id, configs)
3 changes: 3 additions & 0 deletions tests/snapshots/snap_test_exllamav2_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,6 @@

snapshots['test_exllamav2_integration 1'] = '''Hello, cats! Hello, Exllamav2!
'''

snapshots['test_exllamav2_utf_8 1'] = '''Hello, cats! 你好,土豆!
'''
13 changes: 5 additions & 8 deletions tests/snapshots/snap_test_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
__regex_0_0 ::= #'[0-9]+';
__regex_1_0 ::= #'[a-z]+';
__choice_ID_0 ::= __regex_0_0 | __regex_1_0;
__str_2_0_excepted ::= '\\\\.';
__str_2_0 ::= except!(__str_2_0_excepted)('\\\\.');
integer ::= #"-?(0|[1-9]\\\\d*)";
number ::= #"-?(0|[1-9]\\\\d*)(\\\\.\\\\d+)?([eE][+-]?\\\\d+)?";
string ::= #\'"([^\\\\\\\\"\\u0000-\\u001f]|\\\\\\\\["\\\\\\\\bfnrt/]|\\\\\\\\u[0-9A-Fa-f]{4})*"\';
Expand All @@ -32,7 +30,7 @@
__schema_json_0_weight ::= number;
__schema_json_0_name ::= string;
start ::= \'Today, I want to eat \' __choice_food_0 \'\\n\' "My food\'s ID is " __choice_ID_0 \'.\\n\' "\\nWhat\'s more, indentations\\nare handled\\nappropriately. Let\'s " __str_2_0 \'My weight is 14.4kg and my color is pink. This is my personal info json: \' __schema_json_0 \'\\n\';'''
start ::= \'Today, I want to eat \' __choice_food_0 \'\\n\' "My food\'s ID is " __choice_ID_0 \'.\\n\' "\\nWhat\'s more, indentations\\nare handled\\nappropriately." \'My weight is 14.4kg and my color is pink. This is my personal info json: \' __schema_json_0 \'\\n\';'''

snapshots['test_formatter 2'] = '''Today, I want to eat orange
My food's ID is a.
Expand Down Expand Up @@ -106,14 +104,13 @@
}
}

snapshots['test_formatter_str 1'] = '''__str_0_0_excepted ::= '\\\\.' | '!';
__str_0_0 ::= except!(__str_0_0_excepted)('\\\\.'|'!');
snapshots['test_formatter_str 1'] = '''__str_0_0_excepted ::= '.';
__str_0_0 ::= except!(__str_0_0_excepted)('.');
start ::= __str_0_0 '\\n';'''

snapshots['test_formatter_str 2'] = '''🤗"
"I am the father of three beautiful girls, all of whom are also famous in the world of martial arts," said Van, "I want to thank you for taking care of my girls and my mother, and I hope that you will continue to support me in my quest to bring happiness to the world, because I believe that we are all connected through love and kindness, and that\'s why I am happy to have met you today and hope that we can make a new life together as a family, one where we can be happy together and love each other more than anything else in the world, even if we do not know it yet"
"I hope that this relationship will last forever, and that we can find happiness in our lives together, because we are both very special people who are willing to do anything for each other," said Van, "I wish you all the best in your future endeavors, and I look forward to seeing you again soon, and may God bless you all"
Van, a martial artist from the Philippines, is known for his passion for martial arts and has been involved in various martial arts competitions around the world, including the Asian Championships held in Japan, Thailand, Singapore, Malaysia, Hong Kong, and Taiwan'''
"I am the father of three beautiful girls, all of whom are also famous.
'''

snapshots['test_formatter_str 3'] = {
}
7 changes: 3 additions & 4 deletions tests/test_exllamav2_integration.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from copy import deepcopy

from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2Sampler
from exllamav2.generator import ExLlamaV2DynamicGenerator

from formatter import FormatterBuilder
from integrations.exllamav2 import create_formatter_filter
Expand Down Expand Up @@ -37,8 +35,9 @@ def test_exllamav2_utf_8(snapshot):
cache = ExLlamaV2Cache(model, max_seq_len = 65536, lazy = True)
model.load_autosplit(cache, progress = True)
tokenizer = ExLlamaV2Tokenizer(config)

f = FormatterBuilder()
f.append_line("你好,羊驼!")
f.append_line("你好,土豆!")
exllama_filter = create_formatter_filter(model, tokenizer, f)
generator = ExLlamaV2DynamicGenerator(
model=model,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_formatter(snapshot):
def test_formatter_str(snapshot):
FormatterBuilder._formatter_builder_counter = 0
f = FormatterBuilder()
f.append_line(f"{f.str(stop=['.','!', ','])}")
f.append_line(f"{f.str(stop=['.'])}")
model = RWKV("assets/RWKV-5-World-0.4B-v2-20231113-ctx4096.pth", 'cuda fp16')
pipeline = integrations.RWKV.PIPELINE(model, "rwkv_vocab_v20230424", f)
np.random.seed(42)
Expand Down

0 comments on commit a3bb7ce

Please sign in to comment.