Fix llamacpp caching by making LlamaCppTokenizer an outlines Tokenizer #929

Merged 2 commits on May 31, 2024
29 changes: 2 additions & 27 deletions outlines/integrations/llamacpp.py
@@ -26,7 +26,7 @@
"""

import math
from typing import TYPE_CHECKING, Dict, Optional, Set, Type, Union
from typing import TYPE_CHECKING, Optional, Type, Union

import numpy as np
import torch
@@ -36,37 +36,12 @@
from outlines.fsm.guide import CFGGuide, Guide, RegexGuide
from outlines.fsm.json_schema import build_regex_from_schema
from outlines.integrations.utils import convert_json_schema_to_str
from outlines.models.llamacpp import LlamaCppTokenizer

if TYPE_CHECKING:
from llama_cpp import Llama


class LlamaCppTokenizer:
def __init__(self, model: "Llama"):
self.eos_token_id = model.token_eos()
self.eos_token = model.tokenizer().decode([self.eos_token_id])
self.pad_token_id = self.eos_token_id
self.special_tokens: Set[int] = set()

self.vocabulary: Dict[str, int] = dict()

tokenizer = model.tokenizer()

self.decode = tokenizer.decode

# TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved
try:
self.vocabulary = model.tokenizer_.hf_tokenizer.get_vocab()
except AttributeError:
# ###
for t in range(model.n_vocab()):
token_piece = model.tokenizer().decode([t])
self.vocabulary[token_piece] = t

def convert_token_to_string(self, token: str) -> str:
return token


class LogitsProcessor:
"""Bias LlamaCpp generation using a finite state machine.
89 changes: 88 additions & 1 deletion outlines/models/llamacpp.py
@@ -1,15 +1,102 @@
import dataclasses
import pickle
import warnings
from typing import TYPE_CHECKING, Iterator, List, Optional, TypedDict, Union
from typing import (
TYPE_CHECKING,
Dict,
Iterator,
List,
Optional,
Set,
Tuple,
TypedDict,
Union,
)

from typing_extensions import Unpack

from outlines.generate.api import GenerationParameters, SamplingParameters
from outlines.models.tokenizer import Tokenizer

if TYPE_CHECKING:
from llama_cpp import Llama, LogitsProcessorList


class LlamaCppTokenizer(Tokenizer):
def __init__(self, model: "Llama"):
self.eos_token_id = model.token_eos()
self.eos_token = model.tokenizer().decode([self.eos_token_id])
self.pad_token_id = self.eos_token_id
self.special_tokens: Set[int] = set()

self.vocabulary: Dict[str, int] = dict()

self.tokenizer = model.tokenizer()

# TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved
try:
self.vocabulary = model.tokenizer_.hf_tokenizer.get_vocab()
except AttributeError:
# ###
for t in range(model.n_vocab()):
token_piece = model.tokenizer().decode([t])
self.vocabulary[token_piece] = t

# ensure stable ordering of vocabulary
self.vocabulary = {
tok: tok_id
for tok, tok_id in sorted(self.vocabulary.items(), key=lambda x: x[1])
}

self._hash = None

def decode(self, token_ids: List[int]) -> List[str]:
decoded_bytes = self.tokenizer.detokenize(token_ids)
return [decoded_bytes.decode("utf-8", errors="ignore")]

def encode(
self, prompt: Union[str, List[str]], add_bos: bool = True, special: bool = True
) -> Tuple[List[int], List[int]]:
if isinstance(prompt, list):
raise NotImplementedError(
"llama-cpp-python tokenizer doesn't support batch tokenization"
)
token_ids = self.tokenizer.tokenize(
prompt.encode("utf-8", errors="ignore"), add_bos=add_bos, special=special
)
# generate attention mask, missing from llama-cpp-python
attention_mask = [
1 if token_id != self.pad_token_id else 0 for token_id in token_ids
]
return token_ids, attention_mask

def convert_token_to_string(self, token: str) -> str:
return token

def __eq__(self, other):
if not isinstance(other, LlamaCppTokenizer):
return False
return self.__getstate__() == other.__getstate__()

def __hash__(self):
if self._hash is None:
self._hash = hash(pickle.dumps(self))

Member:

I don't think pickle.dumps handles dict key and/or set value ordering, so this might not be a good approach.

I think I see what you were trying to do previously with the sorting, but it doesn't matter for serialization. It might only seem to matter because you're mixing the serialization interface with the equivalence check and hashing. That's not necessary, though; you can compare the relevant member objects directly in __eq__ and take special steps that are good for hashing in __hash__ and only there (e.g. json.dumps(..., sort_keys=True) for dicts seems to be favored by many).

@lapp0 (Contributor, Author), May 31, 2024:

> I think I see what you were trying to do previously with the sorting, but it doesn't matter for serialization. It might only seem to matter because you're mixing the serialization interface with the equivalence check and hashing. That's not necessary

It matters because the only reason Tokenizers are serializable via __setstate__ is because their serialized form is used to create a stable hash for the @cached create_states_mapping(regex_string, tokenizer). Tokenizers are never deserialized. They are serialized for hashing.

> (e.g. json.dumps(..., sort_keys=True) for dicts seems to be favored by many).

json.dumps is much slower than pickle and pickling a Tokenizer is already 0.25 seconds, which matters because every time we create an FSM index we check the cache which has the pickled tokenizer as a key.

Dicts have stable order since 3.6, and while I successfully experimented with this, I don't know of a guarantee pickle maintains order. How about we revert to sorting then pickling to be safe?

But this discussion brings up an important point. Once other index construction bottlenecks are taken care of by #795 maybe we should address the performance issues I just described. We should only calculate the hash of the serialized tokenizer once. This is much better than serializing a tokenizer every single time create_states_mapping() is called.

Member:

> It matters because the only reason Tokenizers are serializable via __setstate__ is because their serialized form is used to create a stable hash for the @cached create_states_mapping(regex_string, tokenizer). Tokenizers are never deserialized. They are serialized for hashing.

You're still mixing up serialization with equivalence and hashing. The hashing we're talking about here (i.e. __hash__) is completely independent of any caching with cache.

Also, if you want to address the potential cache misses due to dict and set ordering, that can be done in CloudpickleDisk. That's where serialization is used for cache indexing purposes.

> json.dumps is much slower than pickle and pickling a Tokenizer is already 0.25 seconds, which matters because every time we create an FSM index we check the cache which has the pickled tokenizer as a key.

We can use whatever is sufficiently performant and accurate.

Member:

> How about we revert to sorting then pickling to be safe?

If that's the best option, then we might as well make sure that self.vocabulary is sorted upon construction/creation. Sometimes these dicts are already sorted by token ID, in which case that canonicalization step would be rather efficient.

@lapp0 (Contributor, Author):

I created #933 to address potential concerns regarding cache misses.

In cloudpickle, dicts are deterministic for versions > 3.7, but sets are not. I pre-sorted the vocabulary in the latest push, and sort the special-tokens set when __getstate__ is called.

Please review the latest changes.
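
To make the ordering concern above concrete, here is a small, self-contained sketch (illustrative only, not part of the PR's diff) of why the vocabulary is sorted by token ID in the constructor and why __getstate__ sorts special_tokens before pickling:

```python
import pickle

# Two vocabularies with identical contents but different insertion order.
vocab_a = {"<s>": 0, "</s>": 1, "hello": 2}
vocab_b = {"hello": 2, "</s>": 1, "<s>": 0}

# pickle serializes dict items in iteration (insertion) order, so the raw
# bytes -- and any hash derived from them -- differ even though the dicts
# compare equal.
assert vocab_a == vocab_b
assert pickle.dumps(vocab_a) != pickle.dumps(vocab_b)

# Canonicalizing by token ID first (as the constructor above does) makes
# the serialized form, and therefore the derived hash, stable.
def canonical(vocab):
    return dict(sorted(vocab.items(), key=lambda kv: kv[1]))

assert pickle.dumps(canonical(vocab_a)) == pickle.dumps(canonical(vocab_b))

# Sets are trickier still: string hash randomization can change iteration
# order between interpreter runs, which is why __getstate__ converts
# special_tokens to a sorted list before it is serialized.
special_tokens = {"<s>", "</s>"}
stable_special_tokens = sorted(special_tokens)
```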

return self._hash

def __getstate__(self):
"""Create a stable representation for outlines.caching"""
return (
self.vocabulary,
self.eos_token_id,
self.eos_token,
self.pad_token_id,
sorted(self.special_tokens),

Member:

I was going to ask why special_tokens isn't sorted in the constructor, but now I don't even see where/how it's being populated.

@brandonwillard (Member), May 31, 2024:

Oh, it might be adapt_tokenizer. That's problematic, since it means that these class instances aren't actually immutable, which—at the very least—invalidates the hashability requirement.

It looks like we need to change adapt_tokenizer so that it returns a completely new Tokenizer instance, or perhaps integrate the changes made by adapt_tokenizer into the Tokenizer subclasses directly.

@lapp0 (Contributor, Author), May 31, 2024:

> I was going to ask why special_tokens isn't sorted in the constructor, but now I don't even see where/how it's being populated.

cloudpickle has non-deterministic behavior for sets (even frozensets), so we need to convert to a sorted list when serializing to ensure a stable hash.

> Oh, it might be adapt_tokenizer. That's problematic, since it means that these class instances aren't actually immutable, which—at the very least—invalidates the hashability requirement.

adapt_tokenizer doesn't apply to llamacpp for now. It's only called in integrations/transformers.py and integrations/vllm.py.

> It looks like we need to change adapt_tokenizer so that it returns a completely new Tokenizer instance, or perhaps integrate the changes made by adapt_tokenizer into the Tokenizer subclasses directly.

IMHO we should create a separate issue to unify / correct the behavior and interfaces of tokenizers in general to prevent the scope of this PR from growing too large. This PR doesn't introduce any new problems with LlamaCppTokenizer, but it does fix the tests which have been failing in main all week.
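
As a purely hypothetical illustration of the immutability concern raised above (none of this code is in the PR): if an adapter mutated a tokenizer after it had been hashed, the hash the cache was keyed on would no longer match the hash the instance reports.

```python
# Toy stand-in for a tokenizer whose hash depends on its mutable state.
class ToyTokenizer:
    def __init__(self, special_tokens):
        self.special_tokens = set(special_tokens)

    def __eq__(self, other):
        return self.special_tokens == other.special_tokens

    def __hash__(self):
        # Derived from current state, like LlamaCppTokenizer above.
        return hash(tuple(sorted(self.special_tokens)))


tok = ToyTokenizer(["<s>"])
hash_before = hash(tok)

# An in-place "adapt" step mutates the instance...
tok.special_tokens.add("</s>")

# ...so the hash any cache entry was keyed on is now stale.
assert hash(tok) != hash_before
```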

)

def __setstate__(self, state):
raise NotImplementedError("Cannot load a pickled llamacpp tokenizer")


class LlamaCppParams(TypedDict, total=False):
suffix: Optional[str]
temperature: float
24 changes: 24 additions & 0 deletions tests/generate/conftest.py
@@ -0,0 +1,24 @@
from importlib import reload

import pytest


@pytest.fixture
def temp_cache_dir():
import os
import tempfile

import outlines.caching
import outlines.fsm.guide

with tempfile.TemporaryDirectory() as tempdir:
os.environ["OUTLINES_CACHE_DIR"] = tempdir
outlines.caching.get_cache.cache_clear()
reload(outlines)
reload(outlines.fsm.guide)
cache_status = outlines.caching._caching_enabled
try:
outlines.caching._caching_enabled = True
yield
finally:
outlines.caching._caching_enabled = cache_status
55 changes: 55 additions & 0 deletions tests/generate/test_integration_llamacpp.py
@@ -279,3 +279,58 @@ def test_llama_cpp_pre_tokenizer_remains_broken():
model = models.llamacpp(repo, model_path)
with pytest.raises(RuntimeError):
generate.choice(model, ["skirt", "dress", "pen", "jacket"])


def test_RegexGuide_caching(model, temp_cache_dir):
import llama_cpp

import outlines.caching
from outlines.fsm.guide import create_states_mapping

assert outlines.caching._caching_enabled

regex = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
prompt = "What is the IP address of the Google DNS servers? "

cache = outlines.caching.get_cache()

# Returns (hits, misses)
_ = cache.stats(enable=True)
assert cache.statistics

assert create_states_mapping.__memory__ is cache

generator = generate.regex(model, regex, sampler=samplers.greedy())
assert cache.stats() == (0, 1)

model_2 = models.llamacpp(
"Qwen/Qwen1.5-0.5B-Chat-GGUF",
"*q2*.gguf",
tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(
"Qwen/Qwen1.5-0.5B-Chat"
),
)
generator_2 = generate.regex(model_2, regex, sampler=samplers.greedy())
assert cache.stats() == (0, 2)

# These two different models and tokenizers should not have the same state
# mapping results
assert (
generator.logits_processor.fsm.states_to_token_maps
!= generator_2.logits_processor.fsm.states_to_token_maps
)

generator_3 = generate.regex(model_2, regex, sampler=samplers.greedy())
assert cache.stats() == (1, 2)
assert (
generator_2.logits_processor.fsm.states_to_token_maps
== generator_3.logits_processor.fsm.states_to_token_maps
)

# Just for fun...
structured = generator(prompt, max_tokens=30)
structured_2 = generator_2(prompt, max_tokens=30)

assert re.fullmatch(regex, structured)
assert re.fullmatch(regex, structured_2)
assert structured != structured_2
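
For context, this is roughly the end-user pattern the caching fix speeds up. A hedged sketch reusing the model repo, file pattern, and regex from the test above (exact arguments may differ in your setup):

```python
import llama_cpp

import outlines.generate as generate
import outlines.models as models
import outlines.samplers as samplers

# Same GGUF repo/file pattern as the test above; assumes it is available.
model = models.llamacpp(
    "Qwen/Qwen1.5-0.5B-Chat-GGUF",
    "*q2*.gguf",
    tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(
        "Qwen/Qwen1.5-0.5B-Chat"
    ),
)

regex = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

# The first generator compiles the FSM index and caches it keyed on
# (regex, tokenizer); with this PR the llama.cpp tokenizer hashes and
# serializes stably, so that key is reproducible across calls.
generator = generate.regex(model, regex, sampler=samplers.greedy())

# A second generator with the same regex and tokenizer should hit the
# cache instead of recompiling the index.
generator_again = generate.regex(model, regex, sampler=samplers.greedy())

print(generator("What is the IP address of the Google DNS servers? ", max_tokens=30))
```
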
22 changes: 0 additions & 22 deletions tests/generate/test_integration_transformers.py
@@ -1,7 +1,6 @@
import datetime
import re
from enum import Enum
from importlib import reload
from typing import List, Union

import pytest
@@ -15,27 +14,6 @@
from outlines.samplers import beam_search, greedy, multinomial


@pytest.fixture
def temp_cache_dir():
import os
import tempfile

import outlines.caching
import outlines.fsm.guide

with tempfile.TemporaryDirectory() as tempdir:
os.environ["OUTLINES_CACHE_DIR"] = tempdir
outlines.caching.get_cache.cache_clear()
reload(outlines)
reload(outlines.fsm.guide)
cache_status = outlines.caching._caching_enabled
try:
outlines.caching._caching_enabled = True
yield
finally:
outlines.caching._caching_enabled = cache_status


def test_transformers_integration_text():
rng = torch.Generator()
rng.manual_seed(10000) # Chosen so <EOS> is generated