diff --git a/outlines/models/transformers.py b/outlines/models/transformers.py
index b069cfe34..da02642d2 100644
--- a/outlines/models/transformers.py
+++ b/outlines/models/transformers.py
@@ -58,7 +58,9 @@ class CodeLlamaTokenizerFast:  # type: ignore
 class TransformerTokenizer(Tokenizer):
     """Represents a tokenizer for models in the `transformers` library."""
 
-    def __init__(self, tokenizer_or_model_name: Union["PreTrainedTokenizerBase", str], **kwargs):
+    def __init__(
+        self, tokenizer_or_model_name: Union["PreTrainedTokenizerBase", str], **kwargs
+    ):
         if isinstance(tokenizer_or_model_name, str):
             from transformers import AutoTokenizer
 
@@ -66,10 +68,12 @@ def __init__(self, tokenizer_or_model_name: Union["PreTrainedTokenizerBase", str
             self.model_name = tokenizer_or_model_name
             # TODO: Do something to make this hashable?
             self.kwargs = kwargs
-            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_or_model_name, **kwargs)
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                tokenizer_or_model_name, **kwargs
+            )
         else:
             self.tokenizer = tokenizer_or_model_name
-        
+
         self.eos_token_id = self.tokenizer.eos_token_id
         self.eos_token = self.tokenizer.eos_token
 
@@ -112,7 +116,9 @@ def convert_token_to_string(self, token: str) -> str:
     def __eq__(self, other):
         if isinstance(other, type(self)):
             if hasattr(self, "model_name") and hasattr(self, "kwargs"):
-                return other.model_name == self.model_name and other.kwargs == self.kwargs
+                return (
+                    other.model_name == self.model_name and other.kwargs == self.kwargs
+                )
             else:
                 return other.tokenizer == self.tokenizer
         return NotImplemented
diff --git a/tests/generate/test_integration_transformers.py b/tests/generate/test_integration_transformers.py
index 1dd260e13..32bef74dd 100644
--- a/tests/generate/test_integration_transformers.py
+++ b/tests/generate/test_integration_transformers.py
@@ -618,7 +618,7 @@ def __call__(
 
 
 def test_transformers_use_existing_model_and_tokenizer():
-    from transformers import AutoTokenizer, AutoModelForCausalLM 
+    from transformers import AutoModelForCausalLM, AutoTokenizer
 
     rng = torch.Generator()
     rng.manual_seed(10000)