From 5cc901253a1baef52e8b17bb4813d85d588056dc Mon Sep 17 00:00:00 2001
From: Michael Wyatt
Date: Mon, 5 Feb 2024 15:03:30 -0800
Subject: [PATCH] Fix for missing EOS token (#408)

---
 mii/modeling/tokenizers.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/mii/modeling/tokenizers.py b/mii/modeling/tokenizers.py
index 527caec2..46190759 100644
--- a/mii/modeling/tokenizers.py
+++ b/mii/modeling/tokenizers.py
@@ -41,7 +41,7 @@ def decode(self, tokens: torch.Tensor) -> str:
 class HFTokenizer(MIITokenizerWrapper):
     def __init__(self, tokenizer: Union[str, object]) -> None:
         if isinstance(tokenizer, str):
-            tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+            tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True)
         tokenizer.pad_token = tokenizer.eos_token
         super().__init__(tokenizer)
 
@@ -51,7 +51,11 @@ def vocab_size(self) -> int:
 
     @property
     def eos_token_id(self) -> int:
-        return self.tokenizer.eos_token_id
+        eos_token_attrs = ["eod", "eos_token_id", "eos_token", "eod_id"]
+        for attr in eos_token_attrs:
+            if getattr(self.tokenizer, attr, None) is not None:
+                return getattr(self.tokenizer, attr)
+        raise ValueError(f"Tokenizer must have one of {eos_token_attrs} attributes.")
 
     def encode(self, input: str) -> torch.Tensor:
         return self.tokenizer.encode(input, return_tensors="pt").flatten()