Support generating multi-byte utf8 characters
ksvladimir authored and rlouf committed Mar 14, 2024
1 parent d7295a7 commit 043117f
Showing 4 changed files with 298 additions and 30 deletions.
14 changes: 11 additions & 3 deletions outlines/fsm/guide.py
@@ -6,7 +6,11 @@

from outlines import grammars
from outlines.caching import cache
from outlines.fsm.regex import create_fsm_index_tokenizer, make_deterministic_fsm
from outlines.fsm.regex import (
create_fsm_index_tokenizer,
make_byte_level_fsm,
make_deterministic_fsm,
)

if TYPE_CHECKING:
from outlines.models.tokenizer import Tokenizer
@@ -114,7 +118,10 @@ def create_states_mapping(
The parameters of the function are used for caching purposes
"""
regex_pattern = interegular.parse_pattern(regex_string)
regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
byte_fsm = make_byte_level_fsm(
regex_pattern.to_fsm().reduce(), keep_utf8=True
)
regex_fsm, _ = make_deterministic_fsm(byte_fsm)
states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
regex_fsm, tokenizer
)
@@ -216,7 +223,8 @@ def create_states_mapping_from_interegular_fsm(
"""Create the variables related to the mapping between states and tokens
The parameters of the function are used for caching purposes
"""
regex_fsm, _ = make_deterministic_fsm(fsm.reduce())
byte_fsm = make_byte_level_fsm(fsm.reduce(), keep_utf8=True)
regex_fsm, _ = make_deterministic_fsm(byte_fsm)
states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
regex_fsm, tokenizer
)
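
For orientation, the guide.py change inserts a byte-level lift between the interegular FSM and index construction. A minimal sketch of the new path, assuming `tokenizer` is any outlines tokenizer instance (the name is illustrative, not part of the diff):

import interegular
from outlines.fsm.regex import (
    create_fsm_index_tokenizer,
    make_byte_level_fsm,
    make_deterministic_fsm,
)

# Character-level FSM for the pattern, lifted to byte level so that
# multi-byte UTF-8 sequences can be matched one byte at a time.
char_fsm = interegular.parse_pattern("[😁-😎]").to_fsm().reduce()
byte_fsm = make_byte_level_fsm(char_fsm, keep_utf8=True)
regex_fsm, _ = make_deterministic_fsm(byte_fsm)

# Index from FSM states to the token ids allowed in each state.
states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
    regex_fsm, tokenizer
)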
118 changes: 96 additions & 22 deletions outlines/fsm/regex.py
@@ -1,3 +1,4 @@
import re
from collections import namedtuple
from functools import lru_cache
from typing import (
@@ -79,21 +80,21 @@ def fsm_info(self):
if self._fsm_info is None:
flat_transition_map_items = np.fromiter(
((a[0], a[1], b) for a, b in self.flat_transition_map.items()),
dtype=np.dtype("i8, i8, i8"),
dtype=np.dtype("int64, int64, int64"),
)
trans_key_to_states_items = np.fromiter(
((k, z) for k, v in self.trans_key_to_states.items() for z in v),
dtype=np.dtype("i8, i8"),
dtype=np.dtype("int64, int64"),
)
alphabet_symbol_mapping_items = np.fromiter(
(
it
for it in self.alphabet._symbol_mapping.items()
if it[0] != anything_else
),
dtype=np.dtype("U1, i8"),
dtype=np.dtype("U2, int64"),
)
nb_finals = np.fromiter(self.finals, dtype=np.dtype("i8"))
nb_finals = np.fromiter(self.finals, dtype=np.dtype("int64"))
self.__dict__["_fsm_info"] = create_fsm_info(
self.initial,
nb_finals,
@@ -108,7 +109,7 @@ def fsm_info(self):

nb_int_list_type = numba.types.ListType(numba.int64)
nb_int_pair_type = numba.types.UniTuple(numba.int64, 2)
nb_unichar_1_type = numba.types.UnicodeCharSeq(1)
nb_unichar_2_type = numba.types.UnicodeCharSeq(2)


@numba.njit(cache=True)
@@ -132,7 +133,9 @@ def create_fsm_info(
(trans_key_and_state[0], trans_key_and_state[1])
] = trans_key_and_state[2]

alphabet_symbol_map = numba.typed.Dict.empty(nb_unichar_1_type, numba.int64)
# use 2-char strings so that we can represent incomplete utf-8 sequences
# as 2-hex-digit pairs
alphabet_symbol_map = numba.typed.Dict.empty(nb_unichar_2_type, numba.int64)
for symbol_and_trans_key in alphabet_symbol_mapping_items:
alphabet_symbol_map[symbol_and_trans_key[0]] = symbol_and_trans_key[1]

@@ -195,7 +198,7 @@ def transition_trie_setdefault(


def byte_symbol(byte: int) -> str:
return f"{byte:02X}"
return f"{byte:02X}" if byte >= 0x80 else chr(byte)


def make_byte_level_fsm(fsm: FSM, keep_utf8=False) -> FSM:
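
With this change, bytes below 0x80 keep their literal one-character symbol, while bytes that only occur inside multi-byte UTF-8 sequences are encoded as two hex digits. For illustration (values worked out from the code above, not repository output):

byte_symbol(0x61)  # -> "a"   (ASCII byte, kept as the character itself)
byte_symbol(0xF0)  # -> "F0"  (lead byte of a 4-byte UTF-8 sequence)
byte_symbol(0x9F)  # -> "9F"  (continuation byte)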
@@ -415,7 +418,7 @@ def _walk_fsm(
alphabet_anything_value: int,
fsm_initial: int,
fsm_finals: Set[int],
input_string: str,
input_string: Sequence[str],
start_state: int,
full_match: bool = True,
) -> List[int]:
@@ -449,7 +452,7 @@

def walk_fsm(
fsm: BetterFSM,
input_string: str,
input_string: Sequence[str],
start_state: int,
full_match: bool = True,
) -> List[int]:
@@ -651,12 +654,12 @@ def state_scan_tokens(
alphabet_anything_value: int,
fsm_initial: int,
fsm_finals: Set[int],
vocabulary: Dict[str, List[int]],
vocabulary: List[Tuple[Sequence[str], Sequence[int]]],
start_state: int,
) -> Set[Tuple[int, int]]:
res = set()

for token, token_ids in vocabulary.items():
for token, token_ids in vocabulary:
state_seq = _walk_fsm(
fsm_transitions,
alphabet_symbol_mapping,
@@ -679,7 +682,7 @@

def create_fsm_index_end_to_end(
fsm_info: FSMInfo,
vocabulary: Dict[str, List[int]],
vocabulary: List[Tuple[Sequence[str], Sequence[int]]],
) -> Dict[int, Set[Tuple[int, int]]]:
"""Create an FSM state-to-vocabulary map/index through end-to-end token parsing."""

@@ -715,30 +718,101 @@ def create_fsm_index_end_to_end(
return states_to_token_subsets


re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$")
re_replacement_seq = re.compile(r"^�+$")


# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
@lru_cache()
def gpt2_bytes_to_unicode():
"""
Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control
characters the bpe code barfs on.
The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
tables between utf-8 bytes and unicode strings.
"""
bs = (
list(range(ord("!"), ord("~") + 1))
+ list(range(ord("¡"), ord("¬") + 1))
+ list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))


@lru_cache()
def gpt2_unicode_to_bytes():
return {v: k for k, v in gpt2_bytes_to_unicode().items()}
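
These tables matter because GPT-2-style tokenizers store raw bytes as remapped unicode characters. As an illustrative REPL session (assumed output, consistent with the gpt2-like test further down), the four UTF-8 bytes of 😈 round-trip through the maps like this:

>>> "😈".encode("utf-8")
b'\xf0\x9f\x98\x88'
>>> "".join(gpt2_bytes_to_unicode()[b] for b in "😈".encode("utf-8"))
'ðŁĺĪ'
>>> [gpt2_unicode_to_bytes()[c] for c in "ðŁĺĪ"]
[240, 159, 152, 136]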


# TODO: Cannot cache typed collections to disk, yet. See
# https://github.com/numba/numba/issues/4698
@lru_cache
def reduced_vocabulary(tokenizer: "Tokenizer"):
def reduced_vocabulary(
tokenizer: "Tokenizer",
) -> Tuple[List[Tuple[Sequence[str], Sequence[int]]], Set[int]]:
"""Create a map from decoded vocabulary tokens to lists of equivalent token ids."""
vocabulary = numba.typed.Dict.empty(
numba.types.string, numba.types.ListType(numba.int64)
)
empty_token_ids = set()
vocabulary: Dict[Union[str, Tuple[str, ...]], List[int]] = {}
for token, token_idx in tokenizer.vocabulary.items():
if token in tokenizer.special_tokens:
continue

token_str = tokenizer.convert_token_to_string(token)
token_str: Union[str, Tuple[str, ...]] = tokenizer.convert_token_to_string(
token
)

if token_str:
vocabulary.setdefault(
token_str,
numba.typed.List.empty_list(numba.int64),
).append(numba.int64(token_idx))
# invalid utf-8 sequences are replaced with � (\ufffd), but there
# might also be tokens specifically for �, ��, ���, etc.
if "\ufffd" in token_str and not re_replacement_seq.match(token):
if re_llama_byte_token.match(token):
# llama-like tokenizers have <0xXX> tokens for all
# bytes >= 0x80 and represent all incomplete utf-8
# sequences using such tokens
token_bytes = [int(token[3:5], 16)]
else:
# gpt2-like tokenizers have multi-byte tokens that can
# have a mix of full and incomplete utf-8 characters,
# for example, b` \xf0` can be one token; these tokenizers
# map each byte to a valid utf-8 character
token_bytes = cast(
List[int], [gpt2_unicode_to_bytes().get(c) for c in token]
)
if None in token_bytes:
raise RuntimeError(
f"Cannot convert token `{token}` ({token_idx}) to bytes: {token_str}"
)
token_str = tuple(byte_symbol(b) for b in token_bytes)

vocabulary.setdefault(token_str, []).append(token_idx)
else:
empty_token_ids.add(numba.int64(token_idx))

return vocabulary, empty_token_ids
vocabulary_nb = numba.typed.List.empty_list(
numba.types.Tuple(
(
nb_unichar_2_type[:],
numba.int64[:],
)
)
)
for token_tuple, token_ids in vocabulary.items():
token_tuple_np = np.fromiter(token_tuple, dtype=np.dtype("U2"))
token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64"))
vocabulary_nb.append((token_tuple_np, token_ids_np))

return vocabulary_nb, empty_token_ids


def create_fsm_index_tokenizer(
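
To make the new reduced_vocabulary behaviour concrete, here is roughly what it builds (before conversion to the numba typed list) for the llama-like mock tokenizer used in the test below; a sketch worked out from the code, not captured output:

# Tokens that decode to valid text are kept as plain strings; byte-fallback
# tokens such as <0xF0> are replaced by tuples of hex byte symbols.
{
    "1": [1],
    "a": [2],
    "😍": [4],
    ("F0",): [5],
    ("9F",): [6],
    ("98",): [7],
    ("88",): [8],
    "\ufffd": [9],        # genuine replacement-character tokens stay as-is
    "\ufffd\ufffd": [10],
}
# "eos" is skipped as a special token, and empty_token_ids stays empty.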
93 changes: 93 additions & 0 deletions tests/fsm/test_guide.py
@@ -67,6 +67,99 @@ def convert_token_to_string(self, token):
assert fsm.is_final_state(state) is True


def test_regex_multi_byte_llama_like():
class MockTokenizer:
vocabulary = {
"1": 1,
"a": 2,
"eos": 3,
"😍": 4,
"<0xF0>": 5,
"<0x9F>": 6,
"<0x98>": 7,
"<0x88>": 8, # 😈
"\ufffd": 9,
"\ufffd\ufffd": 10,
}
special_tokens = {"eos"}
eos_token_id = 3

def convert_token_to_string(self, token):
if token[0] == "<":
return "\ufffd"
return token

regex_str = "[😁-😎]"
tokenizer = MockTokenizer()
fsm = RegexGuide(regex_str, tokenizer)

assert fsm.states_to_token_maps == {
0: {5: 1, 4: 2},
1: {6: 3},
3: {7: 4},
4: {8: 2},
}

instruction = fsm.get_next_instruction(0)
assert isinstance(instruction, Generate)
assert instruction.tokens == [5, 4]

assert fsm.get_next_state(state=0, token_id=5) == 1
assert fsm.get_next_state(state=0, token_id=tokenizer.eos_token_id) == -1

assert fsm.is_final_state(0) is False

for state in fsm.final_states:
assert fsm.is_final_state(state) is True
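
Reading the expected map above: from state 0 the guide can either emit 😍 (token 4) in one step, or spell 😈 byte by byte through the <0xXX> tokens, with both paths ending in the same final state:

# 0 --(4: "😍")----------------------------------------------------------> 2 (final)
# 0 --(5: <0xF0>)--> 1 --(6: <0x9F>)--> 3 --(7: <0x98>)--> 4 --(8: <0x88>)--> 2 (final)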


def test_regex_multi_byte_gpt2_like():
class MockTokenizer:
vocabulary = {
"1": 1,
"a": 2,
"eos": 3,
"😍": 4,
" ": 5,
"\ufffd": 6,
"\ufffd\ufffd": 7,
"ðŁĺ": 8,
"Ī": 9, # '😈'
"Ġð": 10,
"ŁĺĪ": 11, # ' 😈'
}
special_tokens = {"eos"}
eos_token_id = 3

def convert_token_to_string(self, token):
if self.vocabulary[token] >= 8:
return "\ufffd"
return token

regex_str = " [😁-😎]"
tokenizer = MockTokenizer()
fsm = RegexGuide(regex_str, tokenizer)

assert fsm.states_to_token_maps == {
0: {5: 1, 10: 2},
1: {8: 5, 4: 3},
2: {11: 3},
5: {9: 3},
}

instruction = fsm.get_next_instruction(0)
assert isinstance(instruction, Generate)
assert instruction.tokens == [5, 10]

assert fsm.get_next_state(state=0, token_id=5) == 1
assert fsm.get_next_state(state=0, token_id=tokenizer.eos_token_id) == -1

assert fsm.is_final_state(0) is False

for state in fsm.final_states:
assert fsm.is_final_state(state) is True
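
The same idea with a gpt2-style vocabulary: the leading space and the bytes of 😈 can be split across tokens in several ways, all converging on the final state 3:

# 0 --(5: " ")---> 1 --(4: "😍")-------------------> 3 (final)
# 0 --(5: " ")---> 1 --(8: "ðŁĺ")--> 5 --(9: "Ī")--> 3 (final)
# 0 --(10: "Ġð")--> 2 --(11: "ŁĺĪ")----------------> 3 (final)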


def test_regex_final_state():
"""Make sure that the FSM stays in the final state as we keep generating"""

