From 043117f47520c159028bb61bf11db0559d7b6deb Mon Sep 17 00:00:00 2001 From: Volodymyr Kuznetsov Date: Mon, 11 Mar 2024 16:50:15 -0700 Subject: [PATCH] Support generating multi-byte utf8 characters --- outlines/fsm/guide.py | 14 ++++- outlines/fsm/regex.py | 118 ++++++++++++++++++++++++++++++++-------- tests/fsm/test_guide.py | 93 +++++++++++++++++++++++++++++++ tests/fsm/test_regex.py | 103 +++++++++++++++++++++++++++++++++-- 4 files changed, 298 insertions(+), 30 deletions(-) diff --git a/outlines/fsm/guide.py b/outlines/fsm/guide.py index 0690bdf6e..3ef0aed95 100644 --- a/outlines/fsm/guide.py +++ b/outlines/fsm/guide.py @@ -6,7 +6,11 @@ from outlines import grammars from outlines.caching import cache -from outlines.fsm.regex import create_fsm_index_tokenizer, make_deterministic_fsm +from outlines.fsm.regex import ( + create_fsm_index_tokenizer, + make_byte_level_fsm, + make_deterministic_fsm, +) if TYPE_CHECKING: from outlines.models.tokenizer import Tokenizer @@ -114,7 +118,10 @@ def create_states_mapping( The parameters of the function are used for caching purpose """ regex_pattern = interegular.parse_pattern(regex_string) - regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) + byte_fsm = make_byte_level_fsm( + regex_pattern.to_fsm().reduce(), keep_utf8=True + ) + regex_fsm, _ = make_deterministic_fsm(byte_fsm) states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer( regex_fsm, tokenizer ) @@ -216,7 +223,8 @@ def create_states_mapping_from_interegular_fsm( """Create the variables related to the mapping between states and tokens The parameters of the function are used for caching purpose """ - regex_fsm, _ = make_deterministic_fsm(fsm.reduce()) + byte_fsm = make_byte_level_fsm(fsm.reduce(), keep_utf8=True) + regex_fsm, _ = make_deterministic_fsm(byte_fsm) states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer( regex_fsm, tokenizer ) diff --git a/outlines/fsm/regex.py b/outlines/fsm/regex.py index 562ad8d47..0c133ef65 100644 --- a/outlines/fsm/regex.py +++ b/outlines/fsm/regex.py @@ -1,3 +1,4 @@ +import re from collections import namedtuple from functools import lru_cache from typing import ( @@ -79,11 +80,11 @@ def fsm_info(self): if self._fsm_info is None: flat_transition_map_items = np.fromiter( ((a[0], a[1], b) for a, b in self.flat_transition_map.items()), - dtype=np.dtype("i8, i8, i8"), + dtype=np.dtype("int64, int64, int64"), ) trans_key_to_states_items = np.fromiter( ((k, z) for k, v in self.trans_key_to_states.items() for z in v), - dtype=np.dtype("i8, i8"), + dtype=np.dtype("int64, int64"), ) alphabet_symbol_mapping_items = np.fromiter( ( @@ -91,9 +92,9 @@ def fsm_info(self): for it in self.alphabet._symbol_mapping.items() if it[0] != anything_else ), - dtype=np.dtype("U1, i8"), + dtype=np.dtype("U2, int64"), ) - nb_finals = np.fromiter(self.finals, dtype=np.dtype("i8")) + nb_finals = np.fromiter(self.finals, dtype=np.dtype("int64")) self.__dict__["_fsm_info"] = create_fsm_info( self.initial, nb_finals, @@ -108,7 +109,7 @@ def fsm_info(self): nb_int_list_type = numba.types.ListType(numba.int64) nb_int_pair_type = numba.types.UniTuple(numba.int64, 2) -nb_unichar_1_type = numba.types.UnicodeCharSeq(1) +nb_unichar_2_type = numba.types.UnicodeCharSeq(2) @numba.njit(cache=True) @@ -132,7 +133,9 @@ def create_fsm_info( (trans_key_and_state[0], trans_key_and_state[1]) ] = trans_key_and_state[2] - alphabet_symbol_map = numba.typed.Dict.empty(nb_unichar_1_type, numba.int64) + # use 2-char strings so that we can represent incomplete utf-8 
sequences + # as 2-hex-digit pairs + alphabet_symbol_map = numba.typed.Dict.empty(nb_unichar_2_type, numba.int64) for symbol_and_trans_key in alphabet_symbol_mapping_items: alphabet_symbol_map[symbol_and_trans_key[0]] = symbol_and_trans_key[1] @@ -195,7 +198,7 @@ def transition_trie_setdefault( def byte_symbol(byte: int) -> str: - return f"{byte:02X}" + return f"{byte:02X}" if byte >= 0x80 else chr(byte) def make_byte_level_fsm(fsm: FSM, keep_utf8=False) -> FSM: @@ -415,7 +418,7 @@ def _walk_fsm( alphabet_anything_value: int, fsm_initial: int, fsm_finals: Set[int], - input_string: str, + input_string: Sequence[str], start_state: int, full_match: bool = True, ) -> List[int]: @@ -449,7 +452,7 @@ def _walk_fsm( def walk_fsm( fsm: BetterFSM, - input_string: str, + input_string: Sequence[str], start_state: int, full_match: bool = True, ) -> List[int]: @@ -651,12 +654,12 @@ def state_scan_tokens( alphabet_anything_value: int, fsm_initial: int, fsm_finals: Set[int], - vocabulary: Dict[str, List[int]], + vocabulary: List[Tuple[Sequence[str], Sequence[int]]], start_state: int, ) -> Set[Tuple[int, int]]: res = set() - for token, token_ids in vocabulary.items(): + for token, token_ids in vocabulary: state_seq = _walk_fsm( fsm_transitions, alphabet_symbol_mapping, @@ -679,7 +682,7 @@ def state_scan_tokens( def create_fsm_index_end_to_end( fsm_info: FSMInfo, - vocabulary: Dict[str, List[int]], + vocabulary: List[Tuple[Sequence[str], Sequence[int]]], ) -> Dict[int, Set[Tuple[int, int]]]: """Create an FSM state-to-vocabulary map/index through end-to-end token parsing.""" @@ -715,30 +718,101 @@ def create_fsm_index_end_to_end( return states_to_token_subsets +re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$") +re_replacement_seq = re.compile(r"^�+$") + + +# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode +@lru_cache() +def gpt2_bytes_to_unicode(): + """ + Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control + characters the bpe code barfs on. + + The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab + if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for + decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup + tables between utf-8 bytes and unicode strings. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +@lru_cache() +def gpt2_unicode_to_bytes(): + return {v: k for k, v in gpt2_bytes_to_unicode().items()} + + # TODO: Cannot cache typed collections to disk, yet. 
See # https://github.com/numba/numba/issues/4698 @lru_cache -def reduced_vocabulary(tokenizer: "Tokenizer"): +def reduced_vocabulary( + tokenizer: "Tokenizer", +) -> Tuple[List[Tuple[Sequence[str], Sequence[int]]], Set[int]]: """Create a map from decoded vocabulary tokens to lists of equivalent token ids.""" - vocabulary = numba.typed.Dict.empty( - numba.types.string, numba.types.ListType(numba.int64) - ) empty_token_ids = set() + vocabulary: Dict[Union[str, Tuple[str, ...]], List[int]] = {} for token, token_idx in tokenizer.vocabulary.items(): if token in tokenizer.special_tokens: continue - token_str = tokenizer.convert_token_to_string(token) + token_str: Union[str, Tuple[str, ...]] = tokenizer.convert_token_to_string( + token + ) if token_str: - vocabulary.setdefault( - token_str, - numba.typed.List.empty_list(numba.int64), - ).append(numba.int64(token_idx)) + # invalid utf-8 sequences are replaced with � (\ufffd), but there + # might also be tokens specifically for �, ��, ���, etc. + if "\ufffd" in token_str and not re_replacement_seq.match(token): + if re_llama_byte_token.match(token): + # llama-like tokenizers have <0xXX> tokens for all + # bytes >= 0x80 and represent all incomplete utf-8 + # sequences using such tokens + token_bytes = [int(token[3:5], 16)] + else: + # gpt2-like tokenizers have multi-byte tokens that can + # have a mix of full and incomplete utf-8 characters, + # for example, b` \xf0` can be one token; these tokenizers + # map each byte to a valid utf-8 character + token_bytes = cast( + List[int], [gpt2_unicode_to_bytes().get(c) for c in token] + ) + if None in token_bytes: + raise RuntimeError( + f"Cannot convert token `{token}` ({token_idx}) to bytes: {token_str}" + ) + token_str = tuple(byte_symbol(b) for b in token_bytes) + + vocabulary.setdefault(token_str, []).append(token_idx) else: empty_token_ids.add(numba.int64(token_idx)) - return vocabulary, empty_token_ids + vocabulary_nb = numba.typed.List.empty_list( + numba.types.Tuple( + ( + nb_unichar_2_type[:], + numba.int64[:], + ) + ) + ) + for token_tuple, token_ids in vocabulary.items(): + token_tuple_np = np.fromiter(token_tuple, dtype=np.dtype("U2")) + token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64")) + vocabulary_nb.append((token_tuple_np, token_ids_np)) + + return vocabulary_nb, empty_token_ids def create_fsm_index_tokenizer( diff --git a/tests/fsm/test_guide.py b/tests/fsm/test_guide.py index 4be5259d9..aabf0446c 100644 --- a/tests/fsm/test_guide.py +++ b/tests/fsm/test_guide.py @@ -67,6 +67,99 @@ def convert_token_to_string(self, token): assert fsm.is_final_state(state) is True +def test_regex_multi_byte_llama_like(): + class MockTokenizer: + vocabulary = { + "1": 1, + "a": 2, + "eos": 3, + "😍": 4, + "<0xF0>": 5, + "<0x9F>": 6, + "<0x98>": 7, + "<0x88>": 8, # 😈 + "\ufffd": 9, + "\ufffd\ufffd": 10, + } + special_tokens = {"eos"} + eos_token_id = 3 + + def convert_token_to_string(self, token): + if token[0] == "<": + return "\ufffd" + return token + + regex_str = "[😁-😎]" + tokenizer = MockTokenizer() + fsm = RegexGuide(regex_str, tokenizer) + + assert fsm.states_to_token_maps == { + 0: {5: 1, 4: 2}, + 1: {6: 3}, + 3: {7: 4}, + 4: {8: 2}, + } + + instruction = fsm.get_next_instruction(0) + assert isinstance(instruction, Generate) + assert instruction.tokens == [5, 4] + + assert fsm.get_next_state(state=0, token_id=5) == 1 + assert fsm.get_next_state(state=0, token_id=tokenizer.eos_token_id) == -1 + + assert fsm.is_final_state(0) is False + + for state in fsm.final_states: + assert 
fsm.is_final_state(state) is True + + +def test_regex_multi_byte_gpt2_like(): + class MockTokenizer: + vocabulary = { + "1": 1, + "a": 2, + "eos": 3, + "😍": 4, + " ": 5, + "\ufffd": 6, + "\ufffd\ufffd": 7, + "ðŁĺ": 8, + "Ī": 9, # '😈' + "Ġð": 10, + "ŁĺĪ": 11, # ' 😈' + } + special_tokens = {"eos"} + eos_token_id = 3 + + def convert_token_to_string(self, token): + if self.vocabulary[token] >= 8: + return "\ufffd" + return token + + regex_str = " [😁-😎]" + tokenizer = MockTokenizer() + fsm = RegexGuide(regex_str, tokenizer) + + assert fsm.states_to_token_maps == { + 0: {5: 1, 10: 2}, + 1: {8: 5, 4: 3}, + 2: {11: 3}, + 5: {9: 3}, + } + + instruction = fsm.get_next_instruction(0) + assert isinstance(instruction, Generate) + assert instruction.tokens == [5, 10] + + assert fsm.get_next_state(state=0, token_id=5) == 1 + assert fsm.get_next_state(state=0, token_id=tokenizer.eos_token_id) == -1 + + assert fsm.is_final_state(0) is False + + for state in fsm.final_states: + assert fsm.is_final_state(state) is True + + def test_regex_final_state(): """Make sure that the FSM stays in the final state as we keep generating""" diff --git a/tests/fsm/test_regex.py b/tests/fsm/test_regex.py index aee04309f..2fc8a5384 100644 --- a/tests/fsm/test_regex.py +++ b/tests/fsm/test_regex.py @@ -1,5 +1,8 @@ +from typing import Sequence + import interegular import numba +import numpy as np import pytest from transformers import AutoTokenizer @@ -9,6 +12,7 @@ create_fsm_index_tokenizer, fsm_union, get_sub_fsms_from_seq, + make_byte_level_better_fsm, make_byte_level_fsm, make_deterministic_fsm, walk_fsm, @@ -16,13 +20,17 @@ from outlines.models.transformers import TransformerTokenizer +def identity(s): + return s + + def to_bytes(s): return [chr(b) if b < 0x80 else f"{b:02X}" for b in s.encode("utf-8")] def walk_fsm_numba( fsm, - input_string: str, + input_string: Sequence[str], start_state: int, full_match: bool = True, ): @@ -88,6 +96,42 @@ def test_walk_fsm(function): assert res == tuple() +@pytest.mark.parametrize( + "function", + [ + walk_fsm, + walk_fsm_numba, + ], +) +@pytest.mark.parametrize( + "transform", + [ + identity, + to_bytes, + ], +) +def test_walk_fsm_multi_bytes(function, transform): + regex_pattern = interegular.parse_pattern("😂|[😇-😍][😈-😍]*") + str_regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) + regex_fsm = make_byte_level_better_fsm(str_regex_fsm, keep_utf8=True) + + res = tuple(function(regex_fsm, transform("😂"), regex_fsm.initial, full_match=True)) + assert res[-1:] == (1,) + + res = tuple( + function(regex_fsm, transform("😂😂"), regex_fsm.initial, full_match=False) + ) + assert res[-1:] == (1,) + + res = tuple(function(regex_fsm, transform("!"), regex_fsm.initial, full_match=True)) + assert res == tuple() + + res = tuple( + function(regex_fsm, transform("😂😂"), regex_fsm.initial, full_match=True) + ) + assert res == tuple() + + def test_get_sub_fsms_from_seq(): name_pattern = interegular.parse_pattern(r"[^\W\d]\w*") name_fsm, _ = make_deterministic_fsm(name_pattern.to_fsm().reduce()) @@ -257,16 +301,61 @@ def test_create_fsm_index_end_to_end(): "": numba.typed.List([4]), } - vocabulary_nb = numba.typed.Dict.empty( - numba.types.string, numba.types.ListType(numba.int64) + vocabulary_nb = numba.typed.List.empty_list( + numba.types.Tuple( + ( + numba.types.UnicodeCharSeq(2)[:], + numba.int64[:], + ) + ) ) - vocabulary_nb.update(vocabulary) + for token_tuple, token_ids in vocabulary.items(): + token_tuple_np = np.fromiter(token_tuple, dtype=np.dtype("U2")) + token_ids_np = 
np.fromiter(token_ids, dtype=np.dtype("int64")) + vocabulary_nb.append((token_tuple_np, token_ids_np)) res = create_fsm_index_end_to_end(regex_fsm.fsm_info, vocabulary_nb) assert res == {0: {(2, 2), (3, 1)}, 2: {(2, 2), (3, 2)}} +def test_create_fsm_index_end_to_end_multi_byte(): + regex_str = "😇| [😈-😍][😇-😎]*" + + regex_pattern = interegular.parse_pattern(regex_str) + regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) + byte_fsm = make_byte_level_better_fsm(regex_fsm, keep_utf8=True) + + vocabulary = { + "blah": numba.typed.List([0]), + "😈a": numba.typed.List([1]), + "😇": numba.typed.List([2]), + "😍": numba.typed.List([3]), + ("F0", "9F", "98", "8D"): numba.typed.List([4]), # '😍' + " 😍": numba.typed.List([5]), + (" ", "F0", "9F", "98", "8D"): numba.typed.List([6]), # ' 😍' + (" ", "F0", "9F", "98"): numba.typed.List([7]), # ' 😍' incomplete + "": numba.typed.List([8]), + } + + vocabulary_nb = numba.typed.List.empty_list( + numba.types.Tuple( + ( + numba.types.UnicodeCharSeq(2)[:], + numba.int64[:], + ) + ) + ) + for token_tuple, token_ids in vocabulary.items(): + token_tuple_np = np.fromiter(token_tuple, dtype=np.dtype("U2")) + token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64")) + vocabulary_nb.append((token_tuple_np, token_ids_np)) + + res = create_fsm_index_end_to_end(byte_fsm.fsm_info, vocabulary_nb) + + assert res == {0: {(5, 3), (6, 3), (7, 7), (2, 2)}, 3: {(2, 3), (3, 3), (4, 3)}} + + def test_create_fsm_index_tokenizer(): # The combined regular expressions of a lexer state in a Python grammar regex_str = "(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~" @@ -274,15 +363,19 @@ def test_create_fsm_index_tokenizer(): regex_pattern = interegular.parse_pattern(regex_str) # Not reduced, so that there are many states regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm()) + bytes_fsm = make_byte_level_better_fsm(regex_fsm, keep_utf8=True) num_fsm_states = len(regex_fsm.states) assert num_fsm_states == 220 + num_bytes_fsm_states = len(bytes_fsm.states) + assert num_bytes_fsm_states == 235 + tokenizer = AutoTokenizer.from_pretrained("gpt2") tokenizer = TransformerTokenizer(tokenizer) states_to_token_subsets, empty_token_ids = create_fsm_index_tokenizer( - regex_fsm, tokenizer + bytes_fsm, tokenizer ) assert not empty_token_ids
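For reference, the byte-level encoding used throughout this patch keeps ASCII bytes as single characters and spells bytes >= 0x80 as two-hex-digit symbols (see `byte_symbol` in outlines/fsm/regex.py and the `to_bytes` helper in tests/fsm/test_regex.py), which is why the numba alphabet dtype is widened to `UnicodeCharSeq(2)`. A minimal sketch of that convention; `to_byte_symbols` is an illustrative name, not a function added by the patch:

    def byte_symbol(byte: int) -> str:
        # bytes below 0x80 stay as single characters; everything else is
        # rendered as a two-hex-digit symbol, fitting the 2-char ("U2") dtype
        return f"{byte:02X}" if byte >= 0x80 else chr(byte)

    def to_byte_symbols(s: str) -> list:
        # '😈' (U+1F608) encodes to b'\xf0\x9f\x98\x88' -> ['F0', '9F', '98', '88']
        return [byte_symbol(b) for b in s.encode("utf-8")]

    assert to_byte_symbols("a😈") == ["a", "F0", "9F", "98", "88"]

The same symbol sequences are what `make_byte_level_fsm` produces on transitions and what `reduced_vocabulary` emits for tokens containing incomplete UTF-8 sequences, so token walks and FSM transitions match byte by byte.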