Squashed commit of the following:
commit b70206d
Author: Huanghe <[email protected]>
Date:   Mon Nov 4 21:56:32 2024 -0600

    inconsistency fix

commit 30f672d
Author: Huanghe <[email protected]>
Date:   Mon Nov 4 21:55:04 2024 -0600

    Refine docs&types annotations

commit 911112b
Merge: a01d5ba 79255ff
Author: Huanghe <[email protected]>
Date:   Tue Oct 29 12:20:41 2024 -0500

    Merge pull request #21 from lukaszkolodziejczyk/vocab-processors

    Introduce vocabulary processors

commit 79255ff
Author: Lukasz Kolodziejczyk <[email protected]>
Date:   Tue Oct 29 12:23:07 2024 +0100

    docs & others

commit 3f54ad4
Author: Lukasz Kolodziejczyk <[email protected]>
Date:   Tue Oct 29 11:58:04 2024 +0100

    docs

commit 6d2b2d3
Author: Lukasz Kolodziejczyk <[email protected]>
Date:   Tue Oct 29 11:57:58 2024 +0100

    fix

commit 388d28d
Author: Lukasz Kolodziejczyk <[email protected]>
Date:   Tue Oct 29 11:32:39 2024 +0100

    vllm & exllamav2 integrations

commit 3b72ab5
Author: Lukasz Kolodziejczyk <[email protected]>
Date:   Wed Oct 23 13:48:55 2024 +0200

    typing

commit a45f09b
Author: Lukasz Kolodziejczyk <[email protected]>
Date:   Wed Oct 23 13:39:19 2024 +0200

    review

commit 26ff669
Author: Lukasz Kolodziejczyk <[email protected]>
Date:   Wed Oct 16 10:33:44 2024 +0200

    draft vocab processors
Dan-wanna-M committed Nov 5, 2024
1 parent 9272d02 commit 7c18e17
Showing 6 changed files with 158 additions and 90 deletions.
2 changes: 1 addition & 1 deletion src/formatron/integrations/RWKV.py
@@ -8,7 +8,7 @@
from formatron.config import EngineGenerationConfig
from formatron.formatter import FormatterBuilder


__all__ = ["create_engine_vocabulary", "PIPELINE", "PIPELINE_ARGS"]
class PIPELINE_ARGS(rwkv.utils.PIPELINE_ARGS):
"""
A wrapper for the arguments of the pipeline of RWKV.
68 changes: 0 additions & 68 deletions src/formatron/integrations/_utils.py

This file was deleted.

25 changes: 19 additions & 6 deletions src/formatron/integrations/exllamav2.py
@@ -9,31 +9,44 @@
from exllamav2.generator.base import ExLlamaV2Filter
from formatron.config import EngineGenerationConfig
from formatron.formatter import FormatterBase, FormatterBuilder
from formatron.integrations._utils import get_original_characters
from functools import lru_cache
from formatron.integrations.utils import get_original_characters


def create_engine_vocabulary(tokenizer: ExLlamaV2Tokenizer) -> kbnf.Vocabulary:
__all__ = ["create_engine_vocabulary", "create_formatter_filter", "FormatterFilter"]
def create_engine_vocabulary(tokenizer: ExLlamaV2Tokenizer,
vocab_processors: typing.Optional[list[typing.Callable]] = None) -> kbnf.Vocabulary:
"""
Create a vocabulary for the KBNF engine.
Args:
tokenizer: The tokenizer.
vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
"""
assert hasattr(tokenizer.tokenizer_model, "vocab"), (f"tokenizer({tokenizer})"
f" with tokenizer_model({tokenizer.tokenizer_model})"
f" does not have vocab attribute!")
vocab = {tokenizer.tokenizer_model.id_to_piece(
i): i for i in range(tokenizer.tokenizer_model.vocab_size())}
new_vocab = get_original_characters(vocab)
new_vocab = get_original_characters(vocab, vocab_processors)
return kbnf.Vocabulary({k: kbnf.Token(v) for k, v in new_vocab.items()},
{v: k for k, v in vocab.items()})


def create_formatter_filter(model: ExLlamaV2, tokenizer: ExLlamaV2Tokenizer,
formatter_builder: FormatterBuilder,
engine_config: EngineGenerationConfig = None) -> ExLlamaV2Filter:
engine_config: EngineGenerationConfig = None,
vocab_processors: typing.Optional[list[typing.Callable]] = None) -> ExLlamaV2Filter:
"""
Create a formatter filter for the ExLlamaV2 engine.
Args:
model: The ExLlamaV2 model.
tokenizer: The ExLlamaV2 tokenizer.
formatter_builder: The formatter builder.
engine_config: The engine generation configuration.
vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
"""
vocab = create_engine_vocabulary(tokenizer)
vocab = create_engine_vocabulary(tokenizer, vocab_processors)
f = formatter_builder.build(
vocab, lambda tokens: tokenizer.decode(torch.tensor(tokens)))
return FormatterFilter(model, tokenizer, f, engine_config)
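
Since passing vocab_processors explicitly bypasses the auto-detection described in the docstrings above, callers supplying a custom processor should also list any built-in processors they still need. A minimal sketch; the "<NL>" marker token and replace_nl_marker are hypothetical, made up for illustration:

import typing
from formatron.integrations.utils import update_vocab_sentencepiece

def replace_nl_marker(token_to_char: typing.Dict[bytes, bytes]) -> None:
    # Hypothetical processor: map a made-up "<NL>" marker token back to a
    # raw newline byte. Processors mutate token_to_char in place and return None.
    token_to_char[b"<NL>"] = b"\n"

# An explicit list disables auto-detection, so include needed built-ins manually:
# f_filter = create_formatter_filter(model, tokenizer, formatter_builder,
#                                    vocab_processors=[update_vocab_sentencepiece,
#                                                      replace_nl_marker])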
34 changes: 27 additions & 7 deletions src/formatron/integrations/transformers.py
@@ -9,26 +9,39 @@

from formatron.config import EngineGenerationConfig
from formatron.formatter import FormatterBuilder, FormatterBase
from formatron.integrations._utils import get_original_characters
from formatron.integrations.utils import get_original_characters

__all__ = ["create_engine_vocabulary", "create_formatter_logits_processor", "create_formatter_logits_processor_list", "FormattersLogitsProcessor"]

def create_engine_vocabulary(tokenizer: PreTrainedTokenizerBase) -> kbnf.Vocabulary:
def create_engine_vocabulary(tokenizer: PreTrainedTokenizerBase,
vocab_processors: typing.Optional[list[typing.Callable]] = None) -> kbnf.Vocabulary:
"""
Create a vocabulary for the KBNF engine.
Args:
tokenizer: The tokenizer.
vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
"""
vocab = tokenizer.get_vocab()
new_vocab = get_original_characters(vocab)
new_vocab = get_original_characters(vocab, vocab_processors)
return kbnf.Vocabulary({k: kbnf.Token(v) for k, v in new_vocab.items()},
{v: k for k, v in vocab.items()})


def create_formatter_logits_processor(tokenizer: PreTrainedTokenizerBase,
formatter_builders: typing.Sequence[FormatterBuilder | None] | FormatterBuilder,
configs: typing.Sequence[EngineGenerationConfig] = None) -> LogitsProcessor:
configs: typing.Sequence[EngineGenerationConfig] = None,
vocab_processors: typing.Optional[list[typing.Callable]] = None) -> LogitsProcessor:
"""
Create a formatter logits processor.
Args:
tokenizer: The tokenizer.
formatter_builders: The formatter builders.
configs: The engine generation configurations.
vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
"""
vocab = create_engine_vocabulary(tokenizer)
vocab = create_engine_vocabulary(tokenizer, vocab_processors)
if not isinstance(formatter_builders, collections.abc.Sequence):
formatter_builders = [formatter_builders]
formatters = [i.build(vocab, lambda tokens: tokenizer.decode(tokens)) if i is not None else None
@@ -38,13 +51,20 @@ def create_formatter_logits_processor(tokenizer: PreTrainedTokenizerBase,

def create_formatter_logits_processor_list(tokenizer: PreTrainedTokenizerBase,
formatter_builders: typing.Sequence[FormatterBuilder | None] | FormatterBuilder,
configs: typing.Sequence[EngineGenerationConfig] = None) \
configs: typing.Sequence[EngineGenerationConfig] = None,
vocab_processors: typing.Optional[list[typing.Callable]] = None) \
-> LogitsProcessorList:
"""
Create a formatter logits processor list.
Args:
tokenizer: The tokenizer.
formatter_builders: The formatter builders.
configs: The engine generation configurations.
vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
"""
return LogitsProcessorList([create_formatter_logits_processor(tokenizer,
formatter_builders, configs)])
formatter_builders, configs, vocab_processors)])


class FormattersLogitsProcessor(LogitsProcessor):
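
End to end, the new vocab_processors parameter threads through the usage pattern formatron already documents. A minimal sketch of constrained generation with HuggingFace transformers; the "gpt2" checkpoint, prompt, and literal output line are placeholders:

from transformers import AutoModelForCausalLM, AutoTokenizer
from formatron.formatter import FormatterBuilder
from formatron.integrations.transformers import create_formatter_logits_processor_list

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
f = FormatterBuilder()
f.append_line("Hello, world!")  # constrain output to this literal line
# vocab_processors defaults to None, i.e. processors are auto-detected:
logits_processor = create_formatter_logits_processor_list(tokenizer, f)
inputs = tokenizer(["The greeting is:"], return_tensors="pt")
out = model.generate(**inputs, logits_processor=logits_processor, max_new_tokens=10)
print(tokenizer.batch_decode(out))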
93 changes: 93 additions & 0 deletions src/formatron/integrations/utils.py
@@ -0,0 +1,93 @@
import re
import typing
from functools import lru_cache

__all__ = ["get_original_characters", "update_vocab_0xHH", "update_vocab_sentencepiece", "update_vocab_dot_G"]

def _multiple_replace(replacements: typing.Dict[bytes, bytes], regex: re.Pattern[bytes], text: bytes) -> bytes:
# For each match, look-up corresponding value in dictionary
return regex.sub(lambda mo: replacements[mo.group()], text)


def get_original_characters(vocab: typing.Dict[str, int],
processors: typing.Optional[list[typing.Callable]] = None) -> typing.Dict[int, bytes]:
"""
Get a vocabulary of original characters unmangled to raw UTF-8 bytes by the provided processors.
Args:
vocab: The mangled vocabulary.
processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
"""
old_char_to_new_char = {}
assert len(set(vocab.values())) == len(vocab), "Vocabulary contains duplicate token IDs!"
if processors is None:
processors = autodetect_processors(vocab)
for update_vocab in processors:
update_vocab(old_char_to_new_char)
# Create a regular expression from the dictionary keys with longest keys first to avoid conflicts
regex = re.compile(b"(%s)" % b"|".join(sorted(list(map(re.escape, old_char_to_new_char.keys())), key=lambda x: len(x), reverse=True)))
new_vocab = {}
for k in vocab:
token_id = vocab[k]
new_k = _multiple_replace(old_char_to_new_char, regex, k.encode("UTF-8"))
new_vocab[token_id] = new_k
return new_vocab


def autodetect_processors(vocab: typing.Dict[str, int]) -> typing.List[typing.Callable]:
"""
Autodetect vocabulary processors.
"""
result = []
llama_present = any(i.find('<0xF0>') != -1 for i in vocab.keys())
underscore_present = (len([1 for i in vocab.keys() if i.find('\u2581') != -1]) / len(vocab)) > 0.2
g_present = (len([1 for i in vocab.keys() if i.find('\u0120') != -1]) / len(vocab)) > 0.2
if llama_present:
result.append(update_vocab_0xHH)
if underscore_present:
result.append(update_vocab_sentencepiece)
elif g_present:
result.append(update_vocab_dot_G)
return result


def update_vocab_0xHH(token_to_char: typing.Dict[bytes, bytes]):
"""
Vocabulary processor for <0xHH> tokens (used in llama tokenizers)
"""
for j in range(256):
token_to_char[("<0x" + f"{j:02x}".upper() + ">").encode("UTF-8")] = bytes([j])


def update_vocab_sentencepiece(token_to_char: typing.Dict[bytes, bytes]):
"""
Vocabulary processor for the ▁ token (used in sentencepiece tokenizers)
"""
token_to_char["\u2581".encode("UTF-8")] = b" "


def update_vocab_dot_G(token_to_char: typing.Dict[bytes, bytes]):
"""
Vocabulary processor for GPT-2 style token mangling, e.g. space to Ġ and \\n to Ċ (used in huggingface bytelevel preprocessors)
"""
token_to_char.update(_huggingface_bytelevel_decoder())


@lru_cache()
def _huggingface_bytelevel_decoder():
"""
Build the inverse of the GPT-2 byte-to-unicode table: maps each printable stand-in character (as UTF-8 bytes) back to the original byte it represents.
"""
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
n += 1
cs = [chr(n).encode("UTF-8") for n in cs]
for i in range(len(bs)):
bs[i] = bytes([bs[i]])
return dict(zip(cs, bs))
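
To make the new module's behavior concrete, here is a minimal sketch on a toy, hypothetical two-token vocabulary with sentencepiece-style mangling:

from formatron.integrations.utils import get_original_characters, update_vocab_sentencepiece

toy_vocab = {"\u2581Hello": 0, "world": 1}
# Passing processors explicitly bypasses auto-detection; with None,
# autodetect_processors would select update_vocab_sentencepiece here anyway,
# since more than 20% of these tokens contain U+2581.
new_vocab = get_original_characters(toy_vocab, [update_vocab_sentencepiece])
assert new_vocab == {0: b" Hello", 1: b"world"}  # U+2581 restored to a space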
26 changes: 18 additions & 8 deletions src/formatron/integrations/vllm.py
@@ -2,14 +2,13 @@
This module integrates the vllm library by providing convenience utilities.
"""
import collections.abc
import time
import typing
import kbnf
import torch
from vllm import LLM
from formatron.config import EngineGenerationConfig
from formatron.formatter import FormatterBase, FormatterBuilder
from formatron.integrations._utils import get_original_characters
from formatron.integrations.utils import get_original_characters
from vllm.transformers_utils.tokenizer import AnyTokenizer


class FormattersLogitsProcessor:
@@ -97,26 +96,37 @@ def __call__(self, prompt, generated_tokens, logits):
return logits


def create_engine_vocabulary(llm: LLM) -> kbnf.Vocabulary:
def create_engine_vocabulary(tokenizer: AnyTokenizer,
vocab_processors: typing.Optional[list[typing.Callable]] = None) -> kbnf.Vocabulary:
"""
Create a vocabulary for the KBNF engine.
Args:
tokenizer: The tokenizer.
vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
"""
tokenizer = llm.get_tokenizer()
vocab = tokenizer.get_vocab()
new_vocab = get_original_characters(vocab)
new_vocab = get_original_characters(vocab, vocab_processors)
return kbnf.Vocabulary({k: kbnf.Token(v) for k, v in new_vocab.items()}, {
v: k for k, v in vocab.items()})


def create_formatters_logits_processor(llm: LLM,
formatter_builders: typing.Sequence[FormatterBuilder | None] | FormatterBuilder,
configs: typing.Sequence[EngineGenerationConfig] = None) \
configs: typing.Sequence[EngineGenerationConfig] = None,
vocab_processors: typing.Optional[list[typing.Callable]] = None) \
-> FormattersLogitsProcessor:
"""
Create a formatter logits processor.
Args:
llm: The LLM.
formatter_builders: The formatter builders.
configs: The engine generation configurations.
vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
"""
tokenizer = llm.get_tokenizer()
vocab = create_engine_vocabulary(llm)
vocab = create_engine_vocabulary(tokenizer, vocab_processors)
if not isinstance(formatter_builders, collections.abc.Sequence):
formatter_builders = [formatter_builders]
formatters = [i.build(vocab, lambda tokens: tokenizer.decode(tokens)) if i is not None else None
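
A minimal usage sketch for the vllm integration, assuming a vllm version whose SamplingParams accepts logits_processors; the model name, prompt, and constrained text are placeholders:

from vllm import LLM, SamplingParams
from formatron.formatter import FormatterBuilder
from formatron.integrations.vllm import create_formatters_logits_processor

llm = LLM(model="gpt2")
f = FormatterBuilder()
f.append_line("42")  # constrain the completion to a literal line
logits_processor = create_formatters_logits_processor(llm, f)
params = SamplingParams(max_tokens=10, logits_processors=[logits_processor])
outputs = llm.generate(["The answer is:"], params)
print(outputs[0].outputs[0].text)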
