diff --git a/outlines/integrations/llamacpp.py b/outlines/integrations/llamacpp.py
index 74498726d..8e18c33e7 100644
--- a/outlines/integrations/llamacpp.py
+++ b/outlines/integrations/llamacpp.py
@@ -26,7 +26,7 @@
 """
 
 import math
-from typing import TYPE_CHECKING, Dict, Optional, Set, Type, Union
+from typing import TYPE_CHECKING, Optional, Type, Union
 
 import numpy as np
 import torch
@@ -36,47 +36,12 @@
 from outlines.fsm.guide import CFGGuide, Guide, RegexGuide
 from outlines.fsm.json_schema import build_regex_from_schema
 from outlines.integrations.utils import convert_json_schema_to_str
+from outlines.models.llamacpp import LlamaCppTokenizer
 
 if TYPE_CHECKING:
     from llama_cpp import Llama
 
 
-class LlamaCppTokenizer:
-    def __init__(self, model: "Llama"):
-        self.eos_token_id = model.token_eos()
-        self.eos_token = model.tokenizer().decode([self.eos_token_id])
-        self.pad_token_id = self.eos_token_id
-        self.special_tokens: Set[int] = set()
-
-        self.vocabulary: Dict[str, int] = dict()
-
-        tokenizer = model.tokenizer()
-
-        self.decode = tokenizer.decode
-
-        # TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved
-        try:
-            self.vocabulary = model.tokenizer_.hf_tokenizer.get_vocab()
-        except AttributeError:
-            # ###
-            for t in range(model.n_vocab()):
-                token_piece = model.tokenizer().decode([t])
-                self.vocabulary[token_piece] = t
-
-    def convert_token_to_string(self, token: str) -> str:
-        return token
-
-    def __getstate__(self):
-        """Allow tokenizer to be used as hash key by excluding self.decode"""
-        return (
-            self.vocabulary.items(),
-            self.eos_token_id,
-            self.eos_token,
-            self.pad_token_id,
-            sorted(self.special_tokens),
-        )
-
-
 class LogitsProcessor:
     """Bias LlamaCpp generation using a finite state machine.
 
diff --git a/outlines/models/llamacpp.py b/outlines/models/llamacpp.py
index 5920f08d6..5a63e3cfe 100644
--- a/outlines/models/llamacpp.py
+++ b/outlines/models/llamacpp.py
@@ -1,15 +1,96 @@
 import dataclasses
+import pickle
 import warnings
-from typing import TYPE_CHECKING, Iterator, List, Optional, TypedDict, Union
+from typing import (
+    TYPE_CHECKING,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    TypedDict,
+    Union,
+)
 
 from typing_extensions import Unpack
 
 from outlines.generate.api import GenerationParameters, SamplingParameters
+from outlines.models.tokenizer import Tokenizer
 
 if TYPE_CHECKING:
     from llama_cpp import Llama, LogitsProcessorList
 
 
+class LlamaCppTokenizer(Tokenizer):
+    def __init__(self, model: "Llama"):
+        self.eos_token_id = model.token_eos()
+        self.eos_token = model.tokenizer().decode([self.eos_token_id])
+        self.pad_token_id = self.eos_token_id
+        self.special_tokens: Set[int] = set()
+
+        self.vocabulary: Dict[str, int] = dict()
+
+        self.tokenizer = model.tokenizer()
+
+        # TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved
+        try:
+            self.vocabulary = model.tokenizer_.hf_tokenizer.get_vocab()
+        except AttributeError:
+            # ###
+            for t in range(model.n_vocab()):
+                token_piece = model.tokenizer().decode([t])
+                self.vocabulary[token_piece] = t
+
+        self._hash = None
+
+    def decode(self, token_ids: List[int]) -> List[str]:
+        decoded_bytes = self.tokenizer.detokenize(token_ids)
+        return [decoded_bytes.decode("utf-8", errors="ignore")]
+
+    def encode(
+        self, prompt: Union[str, List[str]], add_bos: bool = True, special: bool = True
+    ) -> Tuple[List[int], List[int]]:
+        if isinstance(prompt, list):
+            raise NotImplementedError(
+                "llama-cpp-python tokenizer doesn't support batch tokenization"
+            )
+        token_ids = self.tokenizer.tokenize(
+            prompt.encode("utf-8", errors="ignore"), add_bos=add_bos, special=special
+        )
+        # generate attention mask, missing from llama-cpp-python
+        attention_mask = [
+            1 if token_id != self.pad_token_id else 0 for token_id in token_ids
+        ]
+        return token_ids, attention_mask
+
+    def convert_token_to_string(self, token: str) -> str:
+        return token
+
+    def __eq__(self, other):
+        if not isinstance(other, LlamaCppTokenizer):
+            return False
+        return self.__getstate__() == other.__getstate__()
+
+    def __hash__(self):
+        if self._hash is None:
+            self._hash = hash(pickle.dumps(self))
+        return self._hash
+
+    def __getstate__(self):
+        """Create a stable representation for outlines.caching"""
+        return (
+            sorted(self.vocabulary),
+            self.eos_token_id,
+            self.eos_token,
+            self.pad_token_id,
+            sorted(self.special_tokens),
+        )
+
+    def __setstate__(self, state):
+        raise NotImplementedError("Cannot load a pickled llamacpp tokenizer")
+
+
 class LlamaCppParams(TypedDict, total=False):
     suffix: Optional[str]
     temperature: float
diff --git a/tests/generate/conftest.py b/tests/generate/conftest.py
new file mode 100644
index 000000000..ef7e40eed
--- /dev/null
+++ b/tests/generate/conftest.py
@@ -0,0 +1,24 @@
+from importlib import reload
+
+import pytest
+
+
+@pytest.fixture
+def temp_cache_dir():
+    import os
+    import tempfile
+
+    import outlines.caching
+    import outlines.fsm.guide
+
+    with tempfile.TemporaryDirectory() as tempdir:
+        os.environ["OUTLINES_CACHE_DIR"] = tempdir
+        outlines.caching.get_cache.cache_clear()
+        reload(outlines)
+        reload(outlines.fsm.guide)
+        cache_status = outlines.caching._caching_enabled
+        try:
+            outlines.caching._caching_enabled = True
+            yield
+        finally:
+            outlines.caching._caching_enabled = cache_status
diff --git a/tests/generate/test_integration_llamacpp.py b/tests/generate/test_integration_llamacpp.py
index 531bf8fb9..b7eb8b3cb 100644
--- a/tests/generate/test_integration_llamacpp.py
+++ b/tests/generate/test_integration_llamacpp.py
@@ -281,9 +281,56 @@ def test_llama_cpp_pre_tokenizer_remains_broken():
         generate.choice(model, ["skirt", "dress", "pen", "jacket"])
 
 
-def test_create_states_mapping_llamacpp_tokenizer_regression(model):
-    """Minimal reproducer for #922, error passing llamacpp tokenizer to create_states_mapping"""
+def test_RegexGuide_caching(model, temp_cache_dir):
+    import llama_cpp
+
+    import outlines.caching
     from outlines.fsm.guide import create_states_mapping
-    from outlines.integrations.llamacpp import LlamaCppTokenizer
 
-    create_states_mapping("a", LlamaCppTokenizer(model.model))
+    assert outlines.caching._caching_enabled
+
+    regex = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
+    prompt = "What is the IP address of the Google DNS servers? "
" + + cache = outlines.caching.get_cache() + + # Returns (hits, misses) + _ = cache.stats(enable=True) + assert cache.statistics + + assert create_states_mapping.__memory__ is cache + + generator = generate.regex(model, regex, sampler=samplers.greedy()) + assert cache.stats() == (0, 1) + + model_2 = models.llamacpp( + "Qwen/Qwen1.5-0.5B-Chat-GGUF", + "*q2*.gguf", + tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained( + "Qwen/Qwen1.5-0.5B-Chat" + ), + ) + generator_2 = generate.regex(model_2, regex, sampler=samplers.greedy()) + assert cache.stats() == (0, 2) + + # These two different models and tokenizers should not have the same state + # mapping results + assert ( + generator.logits_processor.fsm.states_to_token_maps + != generator_2.logits_processor.fsm.states_to_token_maps + ) + + generator_3 = generate.regex(model_2, regex, sampler=samplers.greedy()) + assert cache.stats() == (1, 2) + assert ( + generator_2.logits_processor.fsm.states_to_token_maps + == generator_3.logits_processor.fsm.states_to_token_maps + ) + + # Just for fun... + structured = generator(prompt, max_tokens=30) + structured_2 = generator_2(prompt, max_tokens=30) + + assert re.fullmatch(regex, structured) + assert re.fullmatch(regex, structured_2) + assert structured != structured_2 diff --git a/tests/generate/test_integration_transformers.py b/tests/generate/test_integration_transformers.py index cee3ca312..da08bed71 100644 --- a/tests/generate/test_integration_transformers.py +++ b/tests/generate/test_integration_transformers.py @@ -1,7 +1,6 @@ import datetime import re from enum import Enum -from importlib import reload from typing import List, Union import pytest @@ -15,27 +14,6 @@ from outlines.samplers import beam_search, greedy, multinomial -@pytest.fixture -def temp_cache_dir(): - import os - import tempfile - - import outlines.caching - import outlines.fsm.guide - - with tempfile.TemporaryDirectory() as tempdir: - os.environ["OUTLINES_CACHE_DIR"] = tempdir - outlines.caching.get_cache.cache_clear() - reload(outlines) - reload(outlines.fsm.guide) - cache_status = outlines.caching._caching_enabled - try: - outlines.caching._caching_enabled = True - yield - finally: - outlines.caching._caching_enabled = cache_status - - def test_transformers_integration_text(): rng = torch.Generator() rng.manual_seed(10000) # Choosen so is generated