Commit

Merge branch 'solve-833' into main-plus-fixes
lapp0 committed May 31, 2024
2 parents 7723ce8 + da2608d commit 0bde0fd
Showing 9 changed files with 432 additions and 128 deletions.
9 changes: 8 additions & 1 deletion outlines/fsm/parsing.py
@@ -38,6 +38,7 @@
from outlines.fsm.regex import (
fsm_union,
get_sub_fsms_from_seq,
get_token_transitions,
make_deterministic_fsm,
walk_fsm,
)
@@ -569,9 +570,15 @@ def match(self, text, pos, last_fsm_state_seq: Optional[Tuple[int, ...]] = None)

text_part = text[start_pos:]

text_transitions = get_token_transitions(
self.fsm.fsm_info.alphabet_symbol_mapping,
self.fsm.fsm_info.alphabet_anything_value,
text_part,
)

state_seq = walk_fsm(
self.fsm,
text_part,
text_transitions,
start_state,
full_match=self.match_whole,
)
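With this change the lexer's match() no longer hands raw characters to walk_fsm: it first maps the text onto FSM transition keys with get_token_transitions and then walks the FSM over those keys. A minimal sketch of the same flow outside the lexer, assuming an FSM built with make_deterministic_fsm (the pattern and input string are made up for illustration):

import interegular

from outlines.fsm.regex import (
    get_token_transitions,
    make_deterministic_fsm,
    walk_fsm,
)

# Compile a toy regex into a deterministic FSM.
fsm, _ = make_deterministic_fsm(interegular.parse_pattern(r"ab+").to_fsm().reduce())

# Map each character of the text to a transition key first ...
keys = get_token_transitions(
    fsm.fsm_info.alphabet_symbol_mapping,
    fsm.fsm_info.alphabet_anything_value,
    "abb",
)

# ... then walk the FSM over the keys instead of the raw characters.
states = walk_fsm(fsm, keys, fsm.fsm_info.initial, full_match=False)
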
114 changes: 78 additions & 36 deletions outlines/fsm/regex.py
@@ -87,14 +87,11 @@ def fsm_info(self):
((k, z) for k, v in self.trans_key_to_states.items() for z in v),
dtype=np.dtype("int64, int64"),
)
alphabet_symbol_mapping_items = np.fromiter(
(
it
for it in self.alphabet._symbol_mapping.items()
if it[0] != anything_else
),
dtype=np.dtype("U2, int64"),
)
alphabet_symbol_mapping_items = [
(k, v)
for k, v in self.alphabet._symbol_mapping.items()
if k != anything_else
]
nb_finals = np.fromiter(self.finals, dtype=np.dtype("int64"))
self.__dict__["_fsm_info"] = create_fsm_info(
self.initial,
@@ -110,7 +107,7 @@

nb_int_list_type = numba.types.ListType(numba.int64)
nb_int_pair_type = numba.types.UniTuple(numba.int64, 2)
nb_unichar_2_type = numba.types.UnicodeCharSeq(2)
nb_unicode_type = numba.types.unicode_type


@numba.njit(cache=True)
Expand All @@ -136,7 +133,7 @@ def create_fsm_info(

# use 2-char strings so that we can represent incomplete utf-8 sequences
# as 2-hex-digit pairs
alphabet_symbol_map = numba.typed.Dict.empty(nb_unichar_2_type, numba.int64)
alphabet_symbol_map = numba.typed.Dict.empty(nb_unicode_type, numba.int64)
for symbol_and_trans_key in alphabet_symbol_mapping_items:
alphabet_symbol_map[symbol_and_trans_key[0]] = symbol_and_trans_key[1]

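The alphabet map's key type changes from a fixed 2-character sequence to numba's variable-length unicode_type because the "\x00"-prefixed byte escapes introduced below are three characters long. A minimal check that the new typed dict accepts both kinds of key (the entries are hypothetical):

import numba

alphabet_symbol_map = numba.typed.Dict.empty(numba.types.unicode_type, numba.types.int64)
alphabet_symbol_map["a"] = 0          # ordinary single-character symbol
alphabet_symbol_map["\x00E2"] = 1     # three-character escape for byte 0xE2
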
@@ -199,7 +196,7 @@ def transition_trie_setdefault(


def byte_symbol(byte: int) -> str:
return f"{byte:02X}" if byte >= 0x80 else chr(byte)
return f"\x00{byte:02X}" if byte >= 0x80 else chr(byte)


def make_byte_level_fsm(fsm: FSM, keep_utf8=False) -> FSM:
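byte_symbol now prefixes the two hex digits with a NUL character so that, once a token's bytes are joined into a single string (see reduced_vocabulary below), the multi-character byte escapes can still be told apart from literal characters. A couple of worked values:

byte_symbol(0x41)  # -> "A"        (printable ASCII stays a single character)
byte_symbol(0xE2)  # -> "\x00E2"   (bytes >= 0x80 become a three-character escape)

# Joining the symbols of b"\xe2\x9c\x94" (the UTF-8 bytes of "✔") therefore
# yields "\x00E2\x009C\x0094", which get_token_transitions below can split
# back into per-byte symbols by looking for the NUL marker.
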
@@ -415,21 +412,19 @@ def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]:
@numba.njit(nogil=True, cache=True)
def _walk_fsm(
fsm_transitions: Dict[Tuple[int, int], int],
alphabet_symbol_mapping: Dict[str, int],
alphabet_anything_value: int,
fsm_initial: int,
fsm_finals: Set[int],
input_string: Sequence[str],
token_trans_key_seq: Sequence[int],
start_state: int,
full_match: bool = True,
) -> List[int]:
state = start_state
accepted_states: List[int] = numba.typed.List.empty_list(numba.int64)
last_final_idx: int = numba.uint64(0)

for i, symbol in enumerate(input_string):
trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)

# Iterate over token transition key sequence. The transition key
# sequence represents the FSM traversal rules of the token's symbols.
for i, trans_key in enumerate(token_trans_key_seq):
new_state = fsm_transitions.get((state, trans_key))

if new_state is None:
@@ -453,7 +448,7 @@

def walk_fsm(
fsm: BetterFSM,
input_string: Sequence[str],
token_trans_key_seq: Sequence[int],
start_state: int,
full_match: bool = True,
) -> List[int]:
@@ -463,13 +458,11 @@
accepted_states: List[int] = []
last_final_idx: int = 0

alphabet_symbol_mapping = fsm.alphabet._symbol_mapping
alphabet_anything_value = fsm.alphabet.anything_value
fsm_transitions = fsm.flat_transition_map

for i, symbol in enumerate(input_string):
trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)

# Iterate over token transition key sequence. The transition key
# sequence represents the FSM traversal rules of the token's symbols.
for i, trans_key in enumerate(token_trans_key_seq):
new_state = fsm_transitions.get((state, trans_key))

if new_state is None:
@@ -655,24 +648,25 @@ def state_scan_tokens(
alphabet_anything_value: int,
fsm_initial: int,
fsm_finals: Set[int],
vocabulary: List[Tuple[Sequence[str], Sequence[int]]],
vocabulary: List[Tuple[str, Sequence[int]]],
token_trans_key_seqs: List[Sequence[int]],
start_state: int,
) -> Set[Tuple[int, int]]:
res = set()

for token, token_ids in vocabulary:
for (token, token_ids), token_trans_key_seq in zip(
vocabulary, token_trans_key_seqs
):
state_seq = _walk_fsm(
fsm_transitions,
alphabet_symbol_mapping,
alphabet_anything_value,
fsm_initial,
fsm_finals,
token,
token_trans_key_seq,
start_state,
False,
)

if state_seq is not None and len(state_seq) < len(token):
if state_seq is not None and len(state_seq) < len(token_trans_key_seq):
continue

for token_id in token_ids:
@@ -681,9 +675,51 @@
return res


@numba.njit(cache=True, nogil=True)
def get_token_transitions(
alphabet_symbol_mapping: Dict[str, int],
alphabet_anything_value: int,
token_str: str,
) -> Sequence[int]:
trans_key_seq = []
i = 0
while i < len(token_str):
if token_str[i] == "\x00" and i != len(token_str) - 1:
symbol = token_str[i : i + 3]
i += 3
else:
symbol = token_str[i]
i += 1

trans_key_seq.append(
alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
)

trans_key_seq_array = np.empty(len(trans_key_seq), dtype=np.int64)
for j in range(len(trans_key_seq)):
trans_key_seq_array[j] = trans_key_seq[j]
return trans_key_seq_array

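A worked example of the splitting logic above, using a hypothetical alphabet mapping and anything value (only the shape of the result matters):

import numba

mapping = numba.typed.Dict.empty(numba.types.unicode_type, numba.types.int64)
mapping["a"] = 0
mapping["\x00E2"] = 1

# "a" resolves to 0, the escape "\x00E2" to 1, and the unknown "z" falls back
# to the anything value (9 here):
get_token_transitions(mapping, 9, "a\x00E2z")  # -> array([0, 1, 9])
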

@numba.njit(cache=True, nogil=True)
def get_tokens_trans_keys(
alphabet_symbol_mapping: Dict[str, int],
alphabet_anything_value: int,
vocabulary: List[Tuple[str, Sequence[int]]],
) -> List[Sequence[int]]:
tokens_trans_keys = numba.typed.List.empty_list(numba.int64[:])
for token_str, _ in vocabulary:
trans_key_seq_array = get_token_transitions(
alphabet_symbol_mapping, alphabet_anything_value, token_str
)
tokens_trans_keys.append(trans_key_seq_array)

return tokens_trans_keys


def create_fsm_index_end_to_end(
fsm_info: FSMInfo,
vocabulary: List[Tuple[Sequence[str], Sequence[int]]],
vocabulary: List[Tuple[str, Sequence[int]]],
) -> Dict[int, Set[Tuple[int, int]]]:
"""Create an FSM state-to-vocabulary map/index through end-to-end token parsing."""

@@ -699,6 +735,12 @@
desc="Compiling FSM index for all state transitions",
)

tokens_trans_key_seqs = get_tokens_trans_keys(
fsm_info.alphabet_symbol_mapping,
fsm_info.alphabet_anything_value,
vocabulary,
)

while next_states:
start_state = next_states.pop()

@@ -709,6 +751,7 @@
fsm_info.initial,
fsm_info.finals,
vocabulary,
tokens_trans_key_seqs,
start_state,
)

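Taken together, the index build now derives each token's transition-key sequence once, up front, instead of re-resolving symbols against the alphabet for every start state inside state_scan_tokens. A rough end-to-end usage sketch, assuming a tokenizer object that already implements outlines' Tokenizer interface (the regex is illustrative):

import interegular

from outlines.fsm.regex import (
    create_fsm_index_end_to_end,
    make_deterministic_fsm,
    reduced_vocabulary,
)

regex_fsm, _ = make_deterministic_fsm(
    interegular.parse_pattern(r"[0-9]+").to_fsm().reduce()
)
vocabulary, empty_token_ids = reduced_vocabulary(tokenizer)  # tokenizer assumed above
states_to_token_subsets = create_fsm_index_end_to_end(regex_fsm.fsm_info, vocabulary)
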
@@ -771,7 +814,7 @@ def gpt2_unicode_to_bytes():
@lru_cache
def reduced_vocabulary(
tokenizer: "Tokenizer",
) -> Tuple[List[Tuple[Sequence[str], Sequence[int]]], Set[int]]:
) -> Tuple[List[Tuple[str, Sequence[int]]], Set[int]]:
"""Create a map from decoded vocabulary tokens to lists of equivalent token ids."""
empty_token_ids = set()
vocabulary: Dict[Union[str, Tuple[str, ...]], List[int]] = {}
@@ -804,7 +847,7 @@ def reduced_vocabulary(
raise RuntimeError(
f"Cannot convert token `{token}` ({token_idx}) to bytes: {token_str}"
)
token_str = tuple(byte_symbol(b) for b in token_bytes)
token_str = "".join(byte_symbol(b) for b in token_bytes)

vocabulary.setdefault(token_str, []).append(token_idx)
else:
@@ -813,15 +856,14 @@
vocabulary_nb = numba.typed.List.empty_list(
numba.types.Tuple(
(
nb_unichar_2_type[:],
nb_unicode_type,
numba.int64[:],
)
)
)
for token_tuple, token_ids in vocabulary.items():
token_tuple_np = np.fromiter(token_tuple, dtype=np.dtype("U2"))
for token_str, token_ids in vocabulary.items():
token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64"))
vocabulary_nb.append((token_tuple_np, token_ids_np))
vocabulary_nb.append((token_str, token_ids_np))

return vocabulary_nb, empty_token_ids

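Each vocabulary entry is now a (decoded string, token-id array) pair rather than a per-character tuple of symbols: token ids that decode to the same string share one entry, and raw bytes appear through the "\x00"-hex escapes. For example (the ids are hypothetical):

# ("the",          array([464, 262]))   two ids decoding to the same string
# ("\x00E2\x009C", array([9119]))       a token made of the raw bytes b"\xe2\x9c"
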
29 changes: 2 additions & 27 deletions outlines/integrations/llamacpp.py
@@ -26,7 +26,7 @@
"""

import math
from typing import TYPE_CHECKING, Dict, Optional, Set, Type, Union
from typing import TYPE_CHECKING, Optional, Type, Union

import numpy as np
import torch
@@ -36,37 +36,12 @@
from outlines.fsm.guide import CFGGuide, Guide, RegexGuide
from outlines.fsm.json_schema import build_regex_from_schema
from outlines.integrations.utils import convert_json_schema_to_str
from outlines.models.llamacpp import LlamaCppTokenizer

if TYPE_CHECKING:
from llama_cpp import Llama


class LlamaCppTokenizer:
def __init__(self, model: "Llama"):
self.eos_token_id = model.token_eos()
self.eos_token = model.tokenizer().decode([self.eos_token_id])
self.pad_token_id = self.eos_token_id
self.special_tokens: Set[int] = set()

self.vocabulary: Dict[str, int] = dict()

tokenizer = model.tokenizer()

self.decode = tokenizer.decode

# TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved
try:
self.vocabulary = model.tokenizer_.hf_tokenizer.get_vocab()
except AttributeError:
# ###
for t in range(model.n_vocab()):
token_piece = model.tokenizer().decode([t])
self.vocabulary[token_piece] = t

def convert_token_to_string(self, token: str) -> str:
return token


class LogitsProcessor:
"""Bias LlamaCpp generation using a finite state machine.
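The LlamaCppTokenizer class is no longer defined in this integration module; it is imported from outlines.models.llamacpp instead (see the import added above). A minimal usage sketch, assuming the moved class keeps the same constructor:

from llama_cpp import Llama

from outlines.models.llamacpp import LlamaCppTokenizer

llm = Llama("path/to/model.gguf")   # hypothetical model path
tokenizer = LlamaCppTokenizer(llm)  # wraps the llama.cpp tokenizer as before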