diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py
index 12463dd78..0661d08f5 100644
--- a/torchchat/cli/builder.py
+++ b/torchchat/cli/builder.py
@@ -8,6 +8,7 @@
 import os
 import sys
 from dataclasses import dataclass
+from enum import Enum
 from pathlib import Path
 from typing import Any, Dict, Optional, Tuple, Union
 
@@ -237,13 +238,16 @@ def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         speculative_builder_args.pte_path = None
         return speculative_builder_args
 
+class TokenizerType(Enum):
+    NONE = 0
+    TIKTOKEN = 1
+    SENTENCEPIECE = 2
+    HF_TOKENIZER = 3
 
 @dataclass
 class TokenizerArgs:
     tokenizer_path: Optional[Union[Path, str]] = None
-    is_sentencepiece: bool = False
-    is_tiktoken: bool = False
-    is_hf_tokenizer: bool = False
+    tokenizer_type: TokenizerType = TokenizerType.NONE
     t: Optional[Any] = None
 
     def __post_init__(self):
@@ -251,9 +255,7 @@ def __post_init__(self):
             from tokenizer.tiktoken import Tokenizer as TiktokenTokenizer
 
             self.t = TiktokenTokenizer(model_path=str(self.tokenizer_path))
-            self.is_tiktoken = True
-            self.is_sentencepiece = False
-            self.is_hf_tokenizer = False
+            self.tokenizer_type = TokenizerType.TIKTOKEN
             return
         except:
             pass
@@ -262,9 +264,7 @@ def __post_init__(self):
             from sentencepiece import SentencePieceProcessor
 
             self.t = SentencePieceProcessor(model_file=str(self.tokenizer_path))
-            self.is_tiktoken = False
-            self.is_sentencepiece = True
-            self.is_hf_tokenizer = False
+            self.tokenizer_type = TokenizerType.SENTENCEPIECE
             return
         except:
             pass
@@ -273,18 +273,19 @@ def __post_init__(self):
             from tokenizer.hf_tokenizer import HFTokenizer
 
             self.t = HFTokenizer(str(self.tokenizer_path))
-            self.is_tiktoken = False
-            self.is_sentencepiece = False
-            self.is_hf_tokenizer = True
+            self.tokenizer_type = TokenizerType.HF_TOKENIZER
             return
         except:
             pass
 
-        self.is_tiktoken = False
-        self.is_sentencepiece = False
-        self.is_hf_tokenizer = False
-        self.t = None
-        return
+    def is_tiktoken(self) -> bool:
+        return self.tokenizer_type == TokenizerType.TIKTOKEN
+
+    def is_sentencepiece(self) -> bool:
+        return self.tokenizer_type == TokenizerType.SENTENCEPIECE
+
+    def is_hf_tokenizer(self) -> bool:
+        return self.tokenizer_type == TokenizerType.HF_TOKENIZER
 
     def validate_model(
         self,
@@ -294,12 +295,13 @@ def validate_model(
         if model is None:
             return
 
-        if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece]) != 1:
+        if self.tokenizer_type == TokenizerType.NONE:
             raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")
 
-        is_tiktoken = self.is_tiktoken
-        is_sentencepiece = self.is_sentencepiece
-        is_hf_tokenizer = self.is_hf_tokenizer
+        is_tiktoken = self.is_tiktoken()
+        is_sentencepiece = self.is_sentencepiece()
+        is_hf_tokenizer = self.is_hf_tokenizer()
+
         use_tiktoken = model.config.use_tiktoken
         use_hf_tokenizer = model.config.use_hf_tokenizer
         use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)
@@ -651,13 +653,13 @@ def do_nothing(max_batch_size, max_seq_length):
            model = torch.load(builder_args.snapshot_path, weights_only=False)
        except Exception:
            raise RuntimeError(f"Failed to load torchchat snapshot {builder_args.snapshot_path}")
-        # _active_backend() does not allow DSO & AOTI to be true. 
+        # _active_backend() does not allow DSO & AOTI to be true.
        # Choose either.
        from torchchat.utils.build_utils import set_backend
        set_backend (dso=True, pte=False, aoti_package=False)
        if (model.config != config):
            raise RuntimeError("loaded model architecture mismatch")
-        ## 
+        ##
        ## import all libraries with custom kernels ans custom operators
        ## that quantize may be pulling in
        ##
@@ -792,4 +794,4 @@ def tokenizer_setting_to_name(tiktoken: bool, tokenizers: bool) -> str:
         return "TikToken"
     if tokenizers:
         return "Tokenizers"
-    return "SentencePiece"
\ No newline at end of file
+    return "SentencePiece"
diff --git a/torchchat/export.py b/torchchat/export.py
index a1dca61b2..28c9bdfec 100644
--- a/torchchat/export.py
+++ b/torchchat/export.py
@@ -482,7 +482,7 @@ def main(args):
 
     if tokenizer_args is None:
         tokenizer_type = "0"
-    elif tokenizer_args.is_sentencepiece:
+    elif tokenizer_args.is_sentencepiece():
         tokenizer_type = "2" # Corresponding to llama2
     else:
         tokenizer_type = "3" # Corresponding to llama3
diff --git a/torchchat/generate.py b/torchchat/generate.py
index 45c868425..4f90b316f 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -365,14 +365,14 @@ def __init__(
         # must use tiktokenizer.
         # Piggy backing off of this flag then for now to identify llama3
         # without prompting user.
-        self.is_llama3_model = self.tokenizer_args.is_tiktoken
+        self.is_llama3_model = self.tokenizer_args.is_tiktoken()
         if self.is_llama3_model:
             self.chat_formatter = Llama3ChatFormatter(self.tokenizer)
             if generator_args.chat_mode:
                 logger.debug(
                     "Llama3 model detected in chat mode. Using updated sentence schemas"
                 )
-        elif self.tokenizer_args.is_hf_tokenizer:
+        elif self.tokenizer_args.is_hf_tokenizer():
             if not self.tokenizer.has_chat_template():
                 raise ValueError("Tokenizer must have a chat template")
             self.chat_formatter = HFTokenizerChatFormatter(self.tokenizer)
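
For context, a minimal usage sketch of the refactored API (the tokenizer path below is hypothetical; per the patch, `__post_init__` probes the tiktoken, SentencePiece, and HF loaders in order and records the first one that succeeds in a single enum field):

```python
from torchchat.cli.builder import TokenizerArgs, TokenizerType

# Hypothetical checkpoint path, for illustration only.
args = TokenizerArgs(tokenizer_path="checkpoints/llama3/tokenizer.model")

# Call sites switch from three mutually exclusive booleans to accessor
# methods backed by one enum field, so an inconsistent state (two flags
# set at once) can no longer be constructed.
if args.is_tiktoken():
    print("Loaded a tiktoken tokenizer")
elif args.is_sentencepiece():
    print("Loaded a SentencePiece tokenizer")
elif args.is_hf_tokenizer():
    print("Loaded a HF tokenizer")
else:
    # No loader succeeded; validate_model() raises in this state.
    assert args.tokenizer_type == TokenizerType.NONE
```

This also lets the old `sum([...]) != 1` sanity check in `validate_model` collapse to a single comparison against `TokenizerType.NONE`, since exclusivity is now guaranteed by construction.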