Skip to content

Commit

Permalink
Run commit hooks
Browse files · Browse the repository at this point in the history
  • Loading branch information
shawnz authored and rlouf committed Mar 1, 2024
1 parent c1b4ffa commit c4de2e0
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
14 changes: 10 additions & 4 deletions outlines/models/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,22 @@ class CodeLlamaTokenizerFast: # type: ignore
class TransformerTokenizer(Tokenizer):
"""Represents a tokenizer for models in the `transformers` library."""

def __init__(self, tokenizer_or_model_name: Union["PreTrainedTokenizerBase", str], **kwargs):
def __init__(
self, tokenizer_or_model_name: Union["PreTrainedTokenizerBase", str], **kwargs
):
if isinstance(tokenizer_or_model_name, str):
from transformers import AutoTokenizer

kwargs.setdefault("padding_side", "left")
self.model_name = tokenizer_or_model_name
# TODO: Do something to make this hashable?
self.kwargs = kwargs
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_or_model_name, **kwargs)
self.tokenizer = AutoTokenizer.from_pretrained(
tokenizer_or_model_name, **kwargs
)
else:
self.tokenizer = tokenizer_or_model_name

self.eos_token_id = self.tokenizer.eos_token_id
self.eos_token = self.tokenizer.eos_token

Expand Down Expand Up @@ -112,7 +116,9 @@ def convert_token_to_string(self, token: str) -> str:
def __eq__(self, other):
if isinstance(other, type(self)):
if hasattr(self, "model_name") and hasattr(self, "kwargs"):
return other.model_name == self.model_name and other.kwargs == self.kwargs
return (
other.model_name == self.model_name and other.kwargs == self.kwargs
)
else:
return other.tokenizer == self.tokenizer
return NotImplemented
Expand Down
2 changes: 1 addition & 1 deletion tests/generate/test_integration_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,7 @@ def __call__(


def test_transformers_use_existing_model_and_tokenizer():
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer

rng = torch.Generator()
rng.manual_seed(10000)
Expand Down

0 comments on commit c4de2e0

Please sign in to comment.