
Fix llamacpp caching by making LlamaCppTokenizer an outlines Tokenizer #929

Merged 2 commits into dottxt-ai:main on May 31, 2024

Conversation

lapp0 (Contributor) commented May 29, 2024:

Fixes #922

Problem

In integrations/llamacpp.py, LlamaCppTokenizer sets self.decode = model.tokenizer.decode, which is not picklable. Passing a LlamaCppTokenizer to a @cached function therefore triggers cloudpickle.dumps(tokenizer), which fails with ValueError: ctypes objects containing pointers cannot be pickled.

Locally, most tests in test_integration_llamacpp.py fail (full failure details in #922).

Solution

Make LlamaCppTokenizer a subclass of outlines.models.tokenizer.Tokenizer and implement its abstract methods. In particular, __getstate__ is needed so that cloudpickle.dumps works and the tokenizer can be hashed.
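The shape of that fix might look roughly like the following sketch (illustrative only: the llama-cpp-python calls `model.tokenizer()` and `model.token_eos()` and the exact field set are assumptions, not the merged implementation):

```python
from typing import Set

from outlines.models.tokenizer import Tokenizer


class LlamaCppTokenizer(Tokenizer):
    def __init__(self, model):
        # Keep the llama_cpp tokenizer around for encode/decode, but never
        # pickle it: its ctypes handles are what raise the ValueError.
        self.tokenizer = model.tokenizer()
        self.eos_token_id = model.token_eos()
        self.pad_token_id = self.eos_token_id
        self.special_tokens: Set[str] = set()
        self.vocabulary = {}  # token string -> token id, filled from the model

    def __getstate__(self):
        # Only picklable members; this is what cloudpickle.dumps (and hence
        # the @cached hashing) will see.
        return (
            self.vocabulary,
            self.eos_token_id,
            self.pad_token_id,
            sorted(self.special_tokens),
        )

    def __setstate__(self, state):
        raise NotImplementedError("Cannot load a pickled llamacpp tokenizer")
```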

brandonwillard (Member) left a comment:

Also, the most important thing we need is a test that helps avoid issues like this in the future.

(Resolved review thread on outlines/integrations/llamacpp.py, outdated)
@rlouf added the bug and llama.cpp labels on May 30, 2024
lapp0 (Contributor, Author) commented May 30, 2024:

@brandonwillard this should be ready for re-review when you have a chance, by the way. A minimal reproducer test has been added.
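For reference, a rough sketch of what such a reproducer can look like (the `model` fixture and its `tokenizer` attribute are hypothetical; the actual test was added under tests/models/test_llamacpp.py):

```python
import cloudpickle


def test_tokenizer_is_serializable(model):
    """Regression check for #922: @cached functions hash their arguments
    with cloudpickle, so the tokenizer must survive cloudpickle.dumps."""
    tokenizer = model.tokenizer  # hypothetical handle on the LlamaCppTokenizer
    serialized = cloudpickle.dumps(tokenizer)  # previously raised ValueError
    assert isinstance(serialized, bytes)
    # The hash must also be stable so repeated cache lookups can hit.
    assert hash(tokenizer) == hash(tokenizer)
```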

brandonwillard (Member) left a comment:

I just noticed this, but shouldn't LlamaCppTokenizer extend our Tokenizer class in the first place? If that's the case, we can use that interface and simply copy the relevant parts (e.g. serialization methods) of our transformers Tokenizer subclass and be done.

lapp0 (Contributor, Author) commented May 30, 2024:

> I just noticed this, but shouldn't LlamaCppTokenizer extend our Tokenizer class in the first place? If that's the case, we can use that interface and simply copy the relevant parts (e.g. serialization methods) of our transformers Tokenizer subclass and be done.

I agree that we should work towards that in general (relevant discussion #695)

vLLM also uses a separate tokenizer which doesn't guarantee the same interface as outlines.models.tokenizer.Tokenizer. I can open an issue for ensuring all tokenizers are subclasses of outlines Tokenizer.

We should address this fix first, though, to ensure the tests pass in main.

brandonwillard (Member) replied:

> I agree that we should work towards that in general (relevant discussion #695)

We can start that work here. It should only involve a change to LlamaCppTokenizer(Tokenizer) and adding proper __eq__, __hash__, __getstate__, and __setstate__ implementations adapted from TransformerTokenizer. Finally, the test test_RegexGuide_caching can be copied and used for LlamaCppTokenizer.

@lapp0 changed the title from "Fix llamacpp caching by making LlamaCppTokenizer pickleable" to "Fix llamacpp caching by making LlamaCppTokenizer an outlines Tokenizer" on May 30, 2024
lapp0 (Contributor, Author) commented May 30, 2024:

I've implemented LlamaCppTokenizer as a subclass of outlines Tokenizer.

It cannot be loaded from disk because llama_cpp tokenizers aren't picklable. __setstate__ and __getstate__ are implemented solely for use with outlines.caching.

decode() and encode() have been adapted to match the type expectations of Tokenizer, allowing models.llamacpp and models.transformers to be used interchangeably within sequence generators. However, I left batch encoding (tokenizer.encode(List[str])) raising NotImplementedError: llama_cpp doesn't track the tokenizer's padding_side, so I cannot safely pad batch-encoded sequences.
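A rough sketch of those adapted methods (a sketch under assumed llama_cpp `tokenize`/`detokenize` behavior and simplified array handling, not the merged code):

```python
from typing import List, Tuple, Union

import numpy as np
from numpy.typing import NDArray

from outlines.models.tokenizer import Tokenizer


class LlamaCppTokenizer(Tokenizer):
    def encode(
        self, prompt: Union[str, List[str]]
    ) -> Tuple[NDArray[np.int64], NDArray[np.int64]]:
        if isinstance(prompt, list):
            # llama_cpp does not track padding_side, so padded batch
            # encoding cannot be produced safely.
            raise NotImplementedError("llama_cpp does not support batch encoding")
        token_ids = self.tokenizer.tokenize(prompt.encode("utf-8"))
        token_ids = np.array([token_ids], dtype=np.int64)
        # Return (token_ids, attention_mask), matching the Tokenizer interface.
        return token_ids, np.ones_like(token_ids)

    def decode(self, token_ids: NDArray[np.int64]) -> List[str]:
        text = self.tokenizer.detokenize(list(token_ids)).decode("utf-8")
        return [text]
```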

@lapp0 force-pushed the fix-922 branch 3 times, most recently from fdee645 to 9333a25, on May 30, 2024 at 20:29
Comment on lines 70 to 71
def __eq__(self, other):
    return hash(self) == hash(other)
brandonwillard (Member):

This won't be consistent due to hash conflicts. The members/fields of this class—aside from self.tokenizer—should be sufficient to define this method, just as they are for __getstate__ and __hash__. (If they're not, then we need to add any missing members/fields to this class.)

N.B. See __eq__ and __hash__ for more info.

(Resolved review thread on tests/models/test_llamacpp.py, outdated)
@lapp0 force-pushed the fix-922 branch 2 times, most recently from 00d1597 to 280611f, on May 30, 2024 at 21:21
lapp0 (Contributor, Author) commented May 30, 2024:

Thanks for your feedback @brandonwillard !

I've incorporated your requested changes.

In outlines/models/llamacpp.py (outdated, resolved):
def __hash__(self):
    # cache object hash
    if self._hash is None:
        self._hash = hash(pickle.dumps(self))
    return self._hash
brandonwillard (Member):

I don't think pickle.dumps handles dict key and/or set value ordering, so this might not be a good approach.

I think I see what you were trying to do previously with the sorting, but it doesn't matter for serialization. It might only seem to matter because you're mixing the serialization interface with the equivalence check and hashing. That's not necessary, though; you can compare the relevant member objects directly in __eq__ and take special steps that are good for hashing in __hash__ and only there (e.g. json.dumps(..., sort_keys=True) for dicts seems to be favored by many).
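A sketch of the separation being suggested here (member names are illustrative, and `json.dumps(..., sort_keys=True)` is just one possible canonicalization):

```python
import json

from outlines.models.tokenizer import Tokenizer


class LlamaCppTokenizer(Tokenizer):
    def __eq__(self, other):
        # Compare the relevant members directly instead of comparing hashes,
        # which can collide.
        if not isinstance(other, LlamaCppTokenizer):
            return NotImplemented
        return (
            self.vocabulary == other.vocabulary
            and self.eos_token_id == other.eos_token_id
            and self.special_tokens == other.special_tokens
        )

    def __hash__(self):
        # Canonicalize the unordered containers only where hashing needs it.
        return hash(
            (
                json.dumps(self.vocabulary, sort_keys=True),
                self.eos_token_id,
                tuple(sorted(self.special_tokens)),
            )
        )
```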

lapp0 (Contributor, Author) replied May 31, 2024:

> I think I see what you were trying to do previously with the sorting, but it doesn't matter for serialization. It might only seem to matter because you're mixing the serialization interface with the equivalence check and hashing. That's not necessary

It matters because the only reason Tokenizers are serializable via __setstate__ is because their serialized form is used to create a stable hash for the @cached create_states_mapping(regex_string, tokenizer). Tokenizers are never deserialized. They are serialized for hashing.

> (e.g. json.dumps(..., sort_keys=True) for dicts seems to be favored by many).

json.dumps is much slower than pickle, and pickling a Tokenizer already takes 0.25 seconds, which matters because every time we create an FSM index we check the cache, which uses the pickled tokenizer as a key.

Dicts have had stable insertion order since Python 3.6, and while I successfully experimented with this, I don't know of a guarantee that pickle maintains that order. How about we revert to sorting and then pickling, to be safe?

But this discussion brings up an important point. Once other index-construction bottlenecks are taken care of by #795, maybe we should address the performance issues I just described: we should only calculate the hash of the serialized tokenizer once, which is much better than serializing the tokenizer every single time create_states_mapping() is called.

brandonwillard (Member) replied:

> It matters because the only reason Tokenizers are serializable via __setstate__ is because their serialized form is used to create a stable hash for the @cached create_states_mapping(regex_string, tokenizer). Tokenizers are never deserialized. They are serialized for hashing.

You're still mixing up serialization with equivalence and hashing. The hashing we're talking about here (i.e. __hash__) is completely independent of any caching done with the cache decorator.

Also, if you want to address the potential cache misses due to dict and set ordering, that can be done in CloudpickleDisk. That's where serialization is used for cache indexing purposes.

> json.dumps is much slower than pickle, and pickling a Tokenizer already takes 0.25 seconds, which matters because every time we create an FSM index we check the cache, which uses the pickled tokenizer as a key.

We can use whatever is sufficiently performant and accurate.

brandonwillard (Member) replied:

> How about we revert to sorting and then pickling, to be safe?

If that's the best option, then we might as well make sure that self.vocabulary is sorted upon construction/creation. Sometimes these dicts are already sorted by token ID, in which case that canonicalization step would be rather efficient.

lapp0 (Contributor, Author) replied:

I created #933 to address potential concerns regarding cache misses.

In cloudpickle, dict serialization is deterministic for Python versions newer than 3.7, but set serialization is not. I pre-sorted the vocabulary in the latest push, and the special-tokens set is sorted when __getstate__ is called.

Please review the latest changes.
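A sketch of that determinism fix (the constructor signature is simplified for illustration):

```python
from outlines.models.tokenizer import Tokenizer


class LlamaCppTokenizer(Tokenizer):
    """Sketch: only the pieces relevant to deterministic serialization."""

    def __init__(self, vocabulary, eos_token_id, eos_token, pad_token_id, special_tokens):
        # Sort the vocabulary by token id once, at construction, so repeated
        # pickling yields identical bytes.
        self.vocabulary = dict(sorted(vocabulary.items(), key=lambda kv: kv[1]))
        self.eos_token_id = eos_token_id
        self.eos_token = eos_token
        self.pad_token_id = pad_token_id
        self.special_tokens = set(special_tokens)

    def __getstate__(self):
        # Sets have no stable cloudpickle ordering; serialize a sorted list.
        return (
            self.vocabulary,
            self.eos_token_id,
            self.eos_token,
            self.pad_token_id,
            sorted(self.special_tokens),
        )
```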

self.eos_token_id,
self.eos_token,
self.pad_token_id,
sorted(self.special_tokens),
brandonwillard (Member):

I was going to ask why special_tokens isn't sorted in the constructor, but now I don't even see where/how it's being populated.

brandonwillard (Member) replied May 31, 2024:

Oh, it might be adapt_tokenizer. That's problematic, since it means that these class instances aren't actually immutable, which—at the very least—invalidates the hashability requirement.

It looks like we need to change adapt_tokenizer so that it returns a completely new Tokenizer instance, or perhaps integrate the changes made by adapt_tokenizer into the Tokenizer subclasses directly.

lapp0 (Contributor, Author) replied May 31, 2024:

> I was going to ask why special_tokens isn't sorted in the constructor, but now I don't even see where/how it's being populated.

cloudpickle has non-deterministic behavior for sets (even frozensets), so we need to convert to a sorted list when serializing to ensure a stable hash.

> Oh, it might be adapt_tokenizer. That's problematic, since it means that these class instances aren't actually immutable, which—at the very least—invalidates the hashability requirement.

adapt_tokenizer doesn't apply to llamacpp for now. It's only called in integrations/transformers.py and integrations/vllm.py.

> It looks like we need to change adapt_tokenizer so that it returns a completely new Tokenizer instance, or perhaps integrate the changes made by adapt_tokenizer into the Tokenizer subclasses directly.

IMHO we should create a separate issue to unify / correct the behavior and interfaces of tokenizers in general to prevent the scope of this PR from growing too large. This PR doesn't introduce any new problems with LlamaCppTokenizer, but it does fix the tests which have been failing in main all week.

@lapp0 mentioned this pull request on May 31, 2024.
brandonwillard (Member) left a comment:

We can merge this now in order to prevent any more errors, but we need some follow-up issues for the mutability inconsistencies related to these types.

@brandonwillard merged commit 3a7d83b into dottxt-ai:main on May 31, 2024 (5 checks passed).
lapp0 (Contributor, Author) commented Jun 1, 2024:

I agree 100% #936

Labels: bug, llama.cpp (related to the `llama.cpp` integration)
Projects: none yet
Development: successfully merging this pull request may close the issue "llamacpp tests failing in main".
3 participants