Fix llamacpp caching by making LlamaCppTokenizer an outlines Tokenizer #929
@@ -1,15 +1,102 @@
import dataclasses
import pickle
import warnings
from typing import (
    TYPE_CHECKING,
    Dict,
    Iterator,
    List,
    Optional,
    Set,
    Tuple,
    TypedDict,
    Union,
)

from typing_extensions import Unpack

from outlines.generate.api import GenerationParameters, SamplingParameters
from outlines.models.tokenizer import Tokenizer

if TYPE_CHECKING:
    from llama_cpp import Llama, LogitsProcessorList


class LlamaCppTokenizer(Tokenizer):
    def __init__(self, model: "Llama"):
        self.eos_token_id = model.token_eos()
        self.eos_token = model.tokenizer().decode([self.eos_token_id])
        self.pad_token_id = self.eos_token_id
        self.special_tokens: Set[int] = set()

        self.vocabulary: Dict[str, int] = dict()

        self.tokenizer = model.tokenizer()

        # TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved
        try:
            self.vocabulary = model.tokenizer_.hf_tokenizer.get_vocab()
        except AttributeError:
            # ### Fall back: no HF tokenizer available, build the vocabulary token by token
            for t in range(model.n_vocab()):
                token_piece = model.tokenizer().decode([t])
                self.vocabulary[token_piece] = t

        # ensure stable ordering of vocabulary
        self.vocabulary = {
            tok: tok_id
            for tok, tok_id in sorted(self.vocabulary.items(), key=lambda x: x[1])
        }

        self._hash = None

    def decode(self, token_ids: List[int]) -> List[str]:
        decoded_bytes = self.tokenizer.detokenize(token_ids)
        return [decoded_bytes.decode("utf-8", errors="ignore")]

    def encode(
        self, prompt: Union[str, List[str]], add_bos: bool = True, special: bool = True
    ) -> Tuple[List[int], List[int]]:
        if isinstance(prompt, list):
            raise NotImplementedError(
                "llama-cpp-python tokenizer doesn't support batch tokenization"
            )
        token_ids = self.tokenizer.tokenize(
            prompt.encode("utf-8", errors="ignore"), add_bos=add_bos, special=special
        )
        # generate attention mask, missing from llama-cpp-python
        attention_mask = [
            1 if token_id != self.pad_token_id else 0 for token_id in token_ids
        ]
        return token_ids, attention_mask

    def convert_token_to_string(self, token: str) -> str:
        return token

    def __eq__(self, other):
        if not isinstance(other, LlamaCppTokenizer):
            return False
        return self.__getstate__() == other.__getstate__()

    def __hash__(self):
        if self._hash is None:
            self._hash = hash(pickle.dumps(self))
        return self._hash

    def __getstate__(self):
        """Create a stable representation for outlines.caching"""
        return (
            self.vocabulary,
            self.eos_token_id,
            self.eos_token,
            self.pad_token_id,
            sorted(self.special_tokens),
        )

Inline comments on the sorted(self.special_tokens) line:

I was going to ask why …

Oh, it might be … It looks like we need to change …

IMHO we should create a separate issue to unify / correct the behavior and interfaces of tokenizers in general to prevent the scope of this PR from growing too large. This PR doesn't introduce any new problems with …

    def __setstate__(self, state):
        raise NotImplementedError("Cannot load a pickled llamacpp tokenizer")


class LlamaCppParams(TypedDict, total=False):
    suffix: Optional[str]
    temperature: float
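
For orientation, a hedged sketch of how the new tokenizer is meant to be used. The GGUF path below is hypothetical, and the import path assumes the class lives in outlines.models.llamacpp, as the imports in the diff suggest; this is illustrative, not part of the PR.

from llama_cpp import Llama

from outlines.models.llamacpp import LlamaCppTokenizer

# Hypothetical local GGUF model; any llama-cpp-python model works here.
model = Llama("./models/mistral-7b-instruct.Q4_K_M.gguf")
tokenizer = LlamaCppTokenizer(model)

token_ids, attention_mask = tokenizer.encode("Hello, world!")
print(tokenizer.decode(token_ids))

# Two tokenizers built from the same model share a stable hash, which is
# what lets outlines reuse a cached FSM index instead of rebuilding it.
assert hash(tokenizer) == hash(LlamaCppTokenizer(model))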

@@ -0,0 +1,24 @@
from importlib import reload

import pytest


@pytest.fixture
def temp_cache_dir():
    import os
    import tempfile

    import outlines.caching
    import outlines.fsm.guide

    with tempfile.TemporaryDirectory() as tempdir:
        # Point the cache at a throwaway directory and reload the modules
        # that capture the cache location at import time.
        os.environ["OUTLINES_CACHE_DIR"] = tempdir
        outlines.caching.get_cache.cache_clear()
        reload(outlines)
        reload(outlines.fsm.guide)
        cache_status = outlines.caching._caching_enabled
        try:
            outlines.caching._caching_enabled = True
            yield
        finally:
            # Restore the original caching flag once the test finishes.
            outlines.caching._caching_enabled = cache_status
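
As an illustration of how a test might lean on this fixture, here is a hypothetical test (not one added by the PR) that relies only on the behavior visible in the fixture above.

import os


def test_uses_temporary_cache(temp_cache_dir):
    import outlines.caching

    # The fixture forces caching on and points OUTLINES_CACHE_DIR at a
    # throwaway directory, so the test cannot touch the user's real cache.
    assert outlines.caching._caching_enabled is True
    assert os.path.isdir(os.environ["OUTLINES_CACHE_DIR"])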

Review discussion

Reviewer: I don't think pickle.dumps handles dict key and/or set value ordering, so this might not be a good approach.

I think I see what you were trying to do previously with the sorting, but it doesn't matter for serialization. It might only seem to matter because you're mixing the serialization interface with the equivalence check and hashing. That's not necessary, though; you can compare the relevant member objects directly in __eq__ and take special steps that are good for hashing in __hash__ and only there (e.g. json.dumps(..., sort_keys=True) for dicts seems to be favored by many).
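
A minimal sketch of the separation being suggested here, with illustrative class and attribute names rather than the actual outlines code: equality compares the members directly, while hashing canonicalizes the unordered containers on its own.

import json


class TokenizerStateSketch:
    # Illustrative only: __eq__ and __hash__ are kept independent of serialization.
    def __init__(self, vocabulary: dict, special_tokens: set):
        self.vocabulary = vocabulary
        self.special_tokens = special_tokens

    def __eq__(self, other):
        # Direct member comparison; dict/set equality ignores ordering anyway.
        if not isinstance(other, TokenizerStateSketch):
            return NotImplemented
        return (
            self.vocabulary == other.vocabulary
            and self.special_tokens == other.special_tokens
        )

    def __hash__(self):
        # Canonicalize only for hashing, e.g. json.dumps with sort_keys=True.
        return hash(
            (
                json.dumps(self.vocabulary, sort_keys=True),
                tuple(sorted(self.special_tokens)),
            )
        )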

Author: It matters because the only reason Tokenizers are serializable via __setstate__ is that their serialized form is used to create a stable hash for the @cache'd create_states_mapping(regex_string, tokenizer). Tokenizers are never deserialized; they are serialized for hashing. json.dumps is much slower than pickle, and pickling a Tokenizer already takes 0.25 seconds, which matters because every time we create an FSM index we check the cache, which has the pickled tokenizer as a key. Dicts have had stable order since 3.6, and while I successfully experimented with this, I don't know of a guarantee that pickle maintains order. How about we revert to sorting then pickling, to be safe?

But this discussion brings up an important point. Once other index-construction bottlenecks are taken care of by #795, maybe we should address the performance issues I just described: we should only calculate the hash of the serialized tokenizer once. That is much better than serializing the tokenizer every single time create_states_mapping() is called.
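
A rough sketch of the two ideas in this reply, sorting before pickling and computing the hash only once, assuming the state is just a vocabulary dict plus a set of special tokens (names here are illustrative):

import pickle


def stable_hash(vocabulary: dict, special_tokens: set) -> int:
    # Sort the unordered containers so the pickled bytes are reproducible,
    # then hash those bytes.  The result can be stored once (e.g. on the
    # tokenizer instance) so the tokenizer is not re-serialized per lookup.
    state = (
        tuple(sorted(vocabulary.items(), key=lambda kv: kv[1])),
        tuple(sorted(special_tokens)),
    )
    return hash(pickle.dumps(state))

The PR's _hash attribute plays that memoizing role: __hash__ computes the value on first use and returns the stored result afterwards.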

Reviewer: You're still mixing up serialization with equivalence and hashing. The hashing we're talking about here (i.e. __hash__) is completely independent of any caching with cache.

Also, if you want to address the potential cache misses due to dict and set ordering, that can be done in CloudpickleDisk; that's where serialization is used for cache-indexing purposes. We can use whatever is sufficiently performant and accurate.
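
The general shape of that alternative, canonicalizing at the cache-key serialization layer rather than inside the tokenizer, might look like the following. This is a simplified, shallow sketch, not the actual CloudpickleDisk code.

import cloudpickle


def cache_key_bytes(obj) -> bytes:
    # Canonicalize unordered containers before serializing so that
    # logically equal objects map to identical cache keys.
    if isinstance(obj, dict):
        obj = {key: obj[key] for key in sorted(obj)}
    elif isinstance(obj, (set, frozenset)):
        obj = sorted(obj)
    return cloudpickle.dumps(obj)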

Reviewer: If that's the best option, then we might as well make sure that self.vocabulary is sorted upon construction/creation. Sometimes these dicts are already sorted by token ID, in which case that canonicalization step would be rather efficient.

Author: I created #933 to address potential concerns regarding cache misses. In cloudpickle, dicts are deterministic for Python versions > 3.7, but sets are not. I pre-sorted the vocabulary in the latest push, and sort the special-tokens set when __getstate__ is called.

Please review the latest changes.
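
A small worked check of what that resolution buys, using hypothetical values that mirror the __getstate__ layout in the diff above: once the set is sorted, two logically equal states serialize to identical bytes and therefore produce identical cache keys.

import pickle

# State layout: (vocabulary, eos_token_id, eos_token, pad_token_id, sorted(special_tokens))
vocab = {"<s>": 0, "</s>": 1, "hello": 2}

state_a = (vocab, 1, "</s>", 1, sorted({2, 0, 1}))
state_b = (vocab, 1, "</s>", 1, sorted({0, 1, 2}))

# Sorting turns each set into a deterministically ordered list,
# so the pickled forms match byte for byte.
assert pickle.dumps(state_a) == pickle.dumps(state_b)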