Simplify TokenizerArgs.__post_init__ with Enum Tokenizer Type #1535

Merged · 7 commits · Apr 28, 2025
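
This PR collapses the three mutually exclusive boolean flags on TokenizerArgs (is_tiktoken, is_sentencepiece, is_hf_tokenizer) into a single TokenizerType enum field, so the "exactly one tokenizer" invariant is carried by the type itself rather than re-checked by summing booleans. A minimal standalone sketch of the pattern (illustrative only, not torchchat code):

    from enum import Enum

    class TokenizerType(Enum):
        NONE = 0
        TIKTOKEN = 1
        SENTENCEPIECE = 2
        HF_TOKENIZER = 3

    # One enum field replaces three booleans, making contradictory states
    # (e.g. is_tiktoken and is_sentencepiece both True) unrepresentable.
    current = TokenizerType.TIKTOKEN
    print(current == TokenizerType.TIKTOKEN)  # True
    print(current == TokenizerType.NONE)      # False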
50 changes: 26 additions & 24 deletions torchchat/cli/builder.py
@@ -8,6 +8,7 @@
 import os
 import sys
 from dataclasses import dataclass
+from enum import Enum
 from pathlib import Path
 from typing import Any, Dict, Optional, Tuple, Union

@@ -237,23 +238,24 @@ def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         speculative_builder_args.pte_path = None
         return speculative_builder_args

+class TokenizerType(Enum):
+    NONE = 0
+    TIKTOKEN = 1
+    SENTENCEPIECE = 2
+    HF_TOKENIZER = 3
+
 @dataclass
 class TokenizerArgs:
     tokenizer_path: Optional[Union[Path, str]] = None
-    is_sentencepiece: bool = False
-    is_tiktoken: bool = False
-    is_hf_tokenizer: bool = False
+    tokenizer_type: TokenizerType = TokenizerType.NONE
     t: Optional[Any] = None

     def __post_init__(self):
         try:
             from tokenizer.tiktoken import Tokenizer as TiktokenTokenizer

             self.t = TiktokenTokenizer(model_path=str(self.tokenizer_path))
-            self.is_tiktoken = True
-            self.is_sentencepiece = False
-            self.is_hf_tokenizer = False
+            self.tokenizer_type = TokenizerType.TIKTOKEN
             return
         except:
             pass
@@ -262,9 +264,7 @@ def __post_init__(self):
         try:
             from sentencepiece import SentencePieceProcessor

             self.t = SentencePieceProcessor(model_file=str(self.tokenizer_path))
-            self.is_tiktoken = False
-            self.is_sentencepiece = True
-            self.is_hf_tokenizer = False
+            self.tokenizer_type = TokenizerType.SENTENCEPIECE
             return
         except:
             pass
@@ -273,18 +273,19 @@ def __post_init__(self):
         try:
             from tokenizer.hf_tokenizer import HFTokenizer

             self.t = HFTokenizer(str(self.tokenizer_path))
-            self.is_tiktoken = False
-            self.is_sentencepiece = False
-            self.is_hf_tokenizer = True
+            self.tokenizer_type = TokenizerType.HF_TOKENIZER
             return
         except:
             pass

-        self.is_tiktoken = False
-        self.is_sentencepiece = False
-        self.is_hf_tokenizer = False
         self.t = None
-        return
+
+    def is_tiktoken(self) -> bool:
+        return self.tokenizer_type == TokenizerType.TIKTOKEN
+
+    def is_sentencepiece(self) -> bool:
+        return self.tokenizer_type == TokenizerType.SENTENCEPIECE
+
+    def is_hf_tokenizer(self) -> bool:
+        return self.tokenizer_type == TokenizerType.HF_TOKENIZER

     def validate_model(
         self,
@@ -294,12 +295,13 @@ def validate_model(
         if model is None:
             return

-        if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece]) != 1:
+        if self.tokenizer_type == TokenizerType.NONE:
             raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")

-        is_tiktoken = self.is_tiktoken
-        is_sentencepiece = self.is_sentencepiece
-        is_hf_tokenizer = self.is_hf_tokenizer
+        is_tiktoken = self.is_tiktoken()
+        is_sentencepiece = self.is_sentencepiece()
+        is_hf_tokenizer = self.is_hf_tokenizer()

         use_tiktoken = model.config.use_tiktoken
         use_hf_tokenizer = model.config.use_hf_tokenizer
         use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)
@@ -651,13 +653,13 @@ def do_nothing(max_batch_size, max_seq_length):
             model = torch.load(builder_args.snapshot_path, weights_only=False)
         except Exception:
             raise RuntimeError(f"Failed to load torchchat snapshot {builder_args.snapshot_path}")
         # _active_backend() does not allow DSO & AOTI to be true.
         # Choose either.
         from torchchat.utils.build_utils import set_backend
         set_backend(dso=True, pte=False, aoti_package=False)
         if model.config != config:
             raise RuntimeError("loaded model architecture mismatch")
         ##
         ## import all libraries with custom kernels and custom operators
         ## that quantize may be pulling in
         ##
@@ -792,4 +794,4 @@ def tokenizer_setting_to_name(tiktoken: bool, tokenizers: bool) -> str:
         return "TikToken"
     if tokenizers:
         return "Tokenizers"
     return "SentencePiece"
2 changes: 1 addition & 1 deletion torchchat/export.py
@@ -482,7 +482,7 @@ def main(args):

     if tokenizer_args is None:
         tokenizer_type = "0"
-    elif tokenizer_args.is_sentencepiece:
+    elif tokenizer_args.is_sentencepiece():
         tokenizer_type = "2"  # Corresponding to llama2
     else:
         tokenizer_type = "3"  # Corresponding to llama3
4 changes: 2 additions & 2 deletions torchchat/generate.py
@@ -365,14 +365,14 @@ def __init__(
         # must use tiktokenizer.
         # Piggy backing off of this flag then for now to identify llama3
         # without prompting user.
-        self.is_llama3_model = self.tokenizer_args.is_tiktoken
+        self.is_llama3_model = self.tokenizer_args.is_tiktoken()
         if self.is_llama3_model:
             self.chat_formatter = Llama3ChatFormatter(self.tokenizer)
             if generator_args.chat_mode:
                 logger.debug(
                     "Llama3 model detected in chat mode. Using updated sentence schemas"
                 )
-        elif self.tokenizer_args.is_hf_tokenizer:
+        elif self.tokenizer_args.is_hf_tokenizer():
             if not self.tokenizer.has_chat_template():
                 raise ValueError("Tokenizer must have a chat template")
             self.chat_formatter = HFTokenizerChatFormatter(self.tokenizer)
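
generate.py keys its llama3 detection and chat-formatter choice off the tokenizer type, so the flag reads become method calls here as well. A simplified sketch of the dispatch this hunk performs (formatter names returned as strings; the real code instantiates torchchat formatter classes, and the sentencepiece fallback branch is an assumption, not shown in this diff):

    from enum import Enum

    class TokenizerType(Enum):
        NONE = 0
        TIKTOKEN = 1
        SENTENCEPIECE = 2
        HF_TOKENIZER = 3

    def pick_formatter(tokenizer_type: TokenizerType) -> str:
        # Tiktoken currently implies a llama3-family model; HF tokenizers must
        # carry a chat template (the real code raises if it is missing).
        if tokenizer_type == TokenizerType.TIKTOKEN:
            return "Llama3ChatFormatter"
        if tokenizer_type == TokenizerType.HF_TOKENIZER:
            return "HFTokenizerChatFormatter"
        return "Llama2ChatFormatter"  # assumed fallback for sentencepiece models

    print(pick_formatter(TokenizerType.TIKTOKEN))  # Llama3ChatFormatter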