From 98eaf8f1d0a7557340d26b206f34f45dbcd6f16a Mon Sep 17 00:00:00 2001 From: zhenyanzhang Date: Fri, 25 Apr 2025 10:56:35 -0700 Subject: [PATCH 1/7] Simplify `TokenizerArgs.__post_init__` with Enum Tokenizer Type Summary: Simplify `TokenizerArgs.__post_init__` with an enum tokenizer type, since only one of the tokenizer types can be true. We want to touch as little code outside of `__post_init__` as possible at the moment. Test Plan: python torchchat.py generate llama2|llama3|granite-code Reviewers: @Jack-Khuu Subscribers: Issue: https://github.com/pytorch/torchchat/issues/1518 --- torchchat/cli/builder.py | 51 +++++++++++++++++++++++----------------- torchchat/export.py | 2 +- torchchat/generate.py | 4 ++-- 3 files changed, 32 insertions(+), 25 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 12463dd78..1bdc24378 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -8,6 +8,7 @@ import os import sys from dataclasses import dataclass +from enum import Enum from pathlib import Path from typing import Any, Dict, Optional, Tuple, Union @@ -237,13 +238,16 @@ def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs": speculative_builder_args.pte_path = None return speculative_builder_args +class TokenizerType(Enum): + NONE = 0 + TIKTOKEN = 1 + SENTENCEPIECE = 2 + HF_TOKENIZER = 3 @dataclass class TokenizerArgs: tokenizer_path: Optional[Union[Path, str]] = None - is_sentencepiece: bool = False - is_tiktoken: bool = False - is_hf_tokenizer: bool = False + tokenizer_type: TokenizerType = TokenizerType.NONE t: Optional[Any] = None def __post_init__(self): @@ -251,9 +255,7 @@ def __post_init__(self): from tokenizer.tiktoken import Tokenizer as TiktokenTokenizer self.t = TiktokenTokenizer(model_path=str(self.tokenizer_path)) - self.is_tiktoken = True - self.is_sentencepiece = False - self.is_hf_tokenizer = False + self.tokenizer_type = TokenizerType.TIKTOKEN return except: pass @@ -262,9 +264,7 @@ def __post_init__(self): from sentencepiece import SentencePieceProcessor self.t = SentencePieceProcessor(model_file=str(self.tokenizer_path)) - self.is_tiktoken = False - self.is_sentencepiece = True - self.is_hf_tokenizer = False + self.tokenizer_type = TokenizerType.SENTENCEPIECE return except: pass @@ -273,19 +273,24 @@ def __post_init__(self): from tokenizer.hf_tokenizer import HFTokenizer self.t = HFTokenizer(str(self.tokenizer_path)) - self.is_tiktoken = False - self.is_sentencepiece = False - self.is_hf_tokenizer = True + self.tokenizer_type = TokenizerType.HF_TOKENIZER return except: pass - self.is_tiktoken = False - self.is_sentencepiece = False - self.is_hf_tokenizer = False + self.tokenizer_type = TokenizerType.NONE self.t = None return + def is_tiktoken(self) -> bool: + return self.tokenizer_type == TokenizerType.TIKTOKEN + + def is_sentencepiece(self) -> bool: + return self.tokenizer_type == TokenizerType.SENTENCEPIECE + + def is_hf_tokenizer(self) -> bool: + return self.tokenizer_type == TokenizerType.HF_TOKENIZER + def validate_model( self, model: Optional[Model], @@ -294,12 +299,14 @@ def validate_model( if model is None: return - if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece]) != 1: + + is_tiktoken = self.is_tiktoken() + is_sentencepiece = self.is_sentencepiece() + is_hf_tokenizer = self.is_hf_tokenizer() + + if sum([is_tiktoken, is_hf_tokenizer, is_sentencepiece]) != 1: raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}") - is_tiktoken = self.is_tiktoken - is_sentencepiece = 
self.is_sentencepiece - is_hf_tokenizer = self.is_hf_tokenizer use_tiktoken = model.config.use_tiktoken use_hf_tokenizer = model.config.use_hf_tokenizer use_sentencepiece = not (use_tiktoken or use_hf_tokenizer) @@ -651,13 +658,13 @@ def do_nothing(max_batch_size, max_seq_length): model = torch.load(builder_args.snapshot_path, weights_only=False) except Exception: raise RuntimeError(f"Failed to load torchchat snapshot {builder_args.snapshot_path}") - # _active_backend() does not allow DSO & AOTI to be true. + # _active_backend() does not allow DSO & AOTI to be true. # Choose either. from torchchat.utils.build_utils import set_backend set_backend (dso=True, pte=False, aoti_package=False) if (model.config != config): raise RuntimeError("loaded model architecture mismatch") - ## + ## ## import all libraries with custom kernels ans custom operators ## that quantize may be pulling in ## @@ -792,4 +799,4 @@ def tokenizer_setting_to_name(tiktoken: bool, tokenizers: bool) -> str: return "TikToken" if tokenizers: return "Tokenizers" - return "SentencePiece" \ No newline at end of file + return "SentencePiece" diff --git a/torchchat/export.py b/torchchat/export.py index a1dca61b2..28c9bdfec 100644 --- a/torchchat/export.py +++ b/torchchat/export.py @@ -482,7 +482,7 @@ def main(args): if tokenizer_args is None: tokenizer_type = "0" - elif tokenizer_args.is_sentencepiece: + elif tokenizer_args.is_sentencepiece(): tokenizer_type = "2" # Corresponding to llama2 else: tokenizer_type = "3" # Corresponding to llama3 diff --git a/torchchat/generate.py b/torchchat/generate.py index 45c868425..4f90b316f 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -365,14 +365,14 @@ def __init__( # must use tiktokenizer. # Piggy backing off of this flag then for now to identify llama3 # without prompting user. - self.is_llama3_model = self.tokenizer_args.is_tiktoken + self.is_llama3_model = self.tokenizer_args.is_tiktoken() if self.is_llama3_model: self.chat_formatter = Llama3ChatFormatter(self.tokenizer) if generator_args.chat_mode: logger.debug( "Llama3 model detected in chat mode. Using updated sentence schemas" ) - elif self.tokenizer_args.is_hf_tokenizer: + elif self.tokenizer_args.is_hf_tokenizer(): if not self.tokenizer.has_chat_template(): raise ValueError("Tokenizer must have a chat template") self.chat_formatter = HFTokenizerChatFormatter(self.tokenizer) From 63be93afc7f0904ba9f27b582f376f5535e42669 Mon Sep 17 00:00:00 2001 From: zhenyanzhang Date: Fri, 25 Apr 2025 10:56:35 -0700 Subject: [PATCH 2/7] Simplify `TokenizerArgs.__post_init__` with Enum Tokenizer Type Summary: Simplify `TokenizerArgs.__post_init__` with an enum tokenizer type, since only one of the tokenizer types can be true. We want to touch as little code outside of `__post_init__` as possible at the moment. 
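For reviewers, a condensed sketch of the shape this change gives `TokenizerArgs` (trimmed from the diff below; the per-backend constructor calls and the try/except fallback chain inside `__post_init__` are elided, so treat it as illustrative rather than the exact file contents):

    from dataclasses import dataclass
    from enum import Enum
    from pathlib import Path
    from typing import Any, Optional, Union

    class TokenizerType(Enum):
        NONE = 0
        TIKTOKEN = 1
        SENTENCEPIECE = 2
        HF_TOKENIZER = 3

    @dataclass
    class TokenizerArgs:
        tokenizer_path: Optional[Union[Path, str]] = None
        # One enum field replaces the three mutually exclusive booleans
        # (is_tiktoken / is_sentencepiece / is_hf_tokenizer).
        tokenizer_type: TokenizerType = TokenizerType.NONE
        t: Optional[Any] = None

        # Callers keep a boolean-style API, now derived from the enum.
        def is_tiktoken(self) -> bool:
            return self.tokenizer_type == TokenizerType.TIKTOKEN

        def is_sentencepiece(self) -> bool:
            return self.tokenizer_type == TokenizerType.SENTENCEPIECE

        def is_hf_tokenizer(self) -> bool:
            return self.tokenizer_type == TokenizerType.HF_TOKENIZER

Call sites in `torchchat/export.py` and `torchchat/generate.py` switch from reading the old flags to calling these methods (e.g. `tokenizer_args.is_tiktoken()`).
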
Test Plan: python torchchat.py generate llama2|llama3|granite-code Reviewers: @Jack-Khuu Subscribers: Issue: https://github.com/pytorch/torchchat/issues/1518 --- torchchat/cli/builder.py | 97 ++++++++++++++++++++-------------------- 1 file changed, 48 insertions(+), 49 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 1bdc24378..350d953b7 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -16,14 +16,14 @@ import torch._dynamo.config import torch._inductor.config import torch.distributed as dist +from torchchat.distributed.logging_utils import SingletonLogger -from torchchat.distributed.utils import( +from torchchat.distributed.utils import ( Color as color, CUDATrackTime, - init_distributed, GPUMemoryMonitor, + init_distributed, ) -from torchchat.distributed.logging_utils import SingletonLogger from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs from torchchat.model_config.model_config import resolve_model_config @@ -36,7 +36,6 @@ from torchchat.utils.measure_time import measure_time from torchchat.utils.quantize import quantize_model - from torchtune.models.convert_weights import meta_to_tune from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE @@ -188,15 +187,19 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": tp = getattr(args, "tp", 1) chpt_from = getattr(args, "chpt_from", "hf") sdp_backend_dict = { - 'math': torch.nn.attention.SDPBackend.MATH, - 'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION, - 'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION, - 'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION, + "math": torch.nn.attention.SDPBackend.MATH, + "flash_attention": torch.nn.attention.SDPBackend.FLASH_ATTENTION, + "efficient_attention": torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION, + "cudnn_attention": torch.nn.attention.SDPBackend.CUDNN_ATTENTION, } attention_backend = sdp_backend_dict[args.attention_backend] - if args.device == "cpu" and (args.attention_backend == "efficient_attention" - or args.attention_backend == "cudnn_attention"): - print(f"Warning: {args.attention_backend} is not supported on CPU. Using math instead.") + if args.device == "cpu" and ( + args.attention_backend == "efficient_attention" + or args.attention_backend == "cudnn_attention" + ): + print( + f"Warning: {args.attention_backend} is not supported on CPU. Using math instead." 
+ ) attention_backend = torch.nn.attention.SDPBackend.MATH return cls( checkpoint_dir=checkpoint_dir, @@ -238,11 +241,6 @@ def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs": speculative_builder_args.pte_path = None return speculative_builder_args -class TokenizerType(Enum): - NONE = 0 - TIKTOKEN = 1 - SENTENCEPIECE = 2 - HF_TOKENIZER = 3 @dataclass class TokenizerArgs: @@ -278,19 +276,12 @@ def __post_init__(self): except: pass - self.tokenizer_type = TokenizerType.NONE + self.is_tiktoken = False + self.is_sentencepiece = False + self.is_hf_tokenizer = False self.t = None return - def is_tiktoken(self) -> bool: - return self.tokenizer_type == TokenizerType.TIKTOKEN - - def is_sentencepiece(self) -> bool: - return self.tokenizer_type == TokenizerType.SENTENCEPIECE - - def is_hf_tokenizer(self) -> bool: - return self.tokenizer_type == TokenizerType.HF_TOKENIZER - def validate_model( self, model: Optional[Model], @@ -299,22 +290,20 @@ def validate_model( if model is None: return - - is_tiktoken = self.is_tiktoken() - is_sentencepiece = self.is_sentencepiece() - is_hf_tokenizer = self.is_hf_tokenizer() - - if sum([is_tiktoken, is_hf_tokenizer, is_sentencepiece]) != 1: + if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece]) != 1: raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}") + is_tiktoken = self.is_tiktoken + is_sentencepiece = self.is_sentencepiece + is_hf_tokenizer = self.is_hf_tokenizer use_tiktoken = model.config.use_tiktoken use_hf_tokenizer = model.config.use_hf_tokenizer use_sentencepiece = not (use_tiktoken or use_hf_tokenizer) if ( - (is_tiktoken and not use_tiktoken) or - (is_hf_tokenizer and not use_hf_tokenizer) or - (is_sentencepiece and not use_sentencepiece) + (is_tiktoken and not use_tiktoken) + or (is_hf_tokenizer and not use_hf_tokenizer) + or (is_sentencepiece and not use_sentencepiece) ): raise RuntimeError( "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format( @@ -512,6 +501,7 @@ def _load_model(builder_args: BuilderArgs) -> Model: # AOTI-compoiled model will load its own weights. # Release weights here to avoid OOM import gc + if hasattr(model, "model"): model.model = None gc.collect() @@ -569,6 +559,7 @@ def _initialize_model( def do_nothing(max_batch_size, max_seq_length): pass + model.setup_caches = do_nothing model.forward = torch._export.aot_load( @@ -606,6 +597,7 @@ def do_nothing(max_batch_size, max_seq_length): def do_nothing(max_batch_size, max_seq_length): pass + model.setup_caches = do_nothing model.forward = aoti_compiled_model @@ -657,12 +649,15 @@ def do_nothing(max_batch_size, max_seq_length): try: model = torch.load(builder_args.snapshot_path, weights_only=False) except Exception: - raise RuntimeError(f"Failed to load torchchat snapshot {builder_args.snapshot_path}") + raise RuntimeError( + f"Failed to load torchchat snapshot {builder_args.snapshot_path}" + ) # _active_backend() does not allow DSO & AOTI to be true. # Choose either. 
from torchchat.utils.build_utils import set_backend - set_backend (dso=True, pte=False, aoti_package=False) - if (model.config != config): + + set_backend(dso=True, pte=False, aoti_package=False) + if model.config != config: raise RuntimeError("loaded model architecture mismatch") ## ## import all libraries with custom kernels ans custom operators @@ -680,7 +675,9 @@ def do_nothing(max_batch_size, max_seq_length): logger = SingletonLogger.get_logger() gpu_memory_monitor = GPUMemoryMonitor("cuda") - logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}") + logger.info( + f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}" + ) # Model-level config if builder_args.params_table: @@ -691,20 +688,16 @@ def do_nothing(max_batch_size, max_seq_length): config = TransformerArgs.from_params(model_config.transformer_args["text"]) logger.info(f"Transformer Config: {config}") - #TODO: Move into head of file after solving circular import - from torchchat.distributed.checkpoint_utils import ( - load_model_weights, - ) + # TODO: Move into head of file after solving circular import + from torchchat.distributed.checkpoint_utils import load_model_weights # Validate pipeline degree assert config.n_layers % pp_degree == 0 # Create device mesh device_mesh = dist.init_device_mesh( - "cuda", - (pp_degree, tp_degree), - mesh_dim_names=("pp", "tp") - ) + "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp") + ) tp_mesh = device_mesh["tp"] pp_mesh = device_mesh["pp"] logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}") @@ -733,7 +726,13 @@ def do_nothing(max_batch_size, max_seq_length): # Load weights logger.info(f"Loading weights for {pp_rank=} on {device=}") with CUDATrackTime() as timer: - load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from) + load_model_weights( + model, + builder_args.distribution_path, + device, + config, + builder_args.chpt_from, + ) logger.info( f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}" @@ -747,7 +746,7 @@ def do_nothing(max_batch_size, max_seq_length): # lanes. 
# TODO: bump up the lane count pipeline_lanes = 1 - seqlen_prefill=1024 + seqlen_prefill = 1024 with device: model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes) From 379c07b01aead89d58ab9f7487d8bb6daa6ebe77 Mon Sep 17 00:00:00 2001 From: zhenyan-zhang-meta Date: Fri, 25 Apr 2025 16:44:48 -0700 Subject: [PATCH 3/7] Add check no tokenizer --- torchchat/cli/builder.py | 34 ++++++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 350d953b7..256d3d8d0 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -241,6 +241,11 @@ def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs": speculative_builder_args.pte_path = None return speculative_builder_args +class TokenizerType(Enum): + NONE = 0 + TIKTOKEN = 1 + SENTENCEPIECE = 2 + HF_TOKENIZER = 3 @dataclass class TokenizerArgs: @@ -276,12 +281,24 @@ def __post_init__(self): except: pass - self.is_tiktoken = False - self.is_sentencepiece = False - self.is_hf_tokenizer = False - self.t = None return + def is_tiktoken(self) -> bool: + return self.tokenizer_type == TokenizerType.TIKTOKEN + + def is_sentencepiece(self) -> bool: + return self.tokenizer_type == TokenizerType.SENTENCEPIECE + + def is_hf_tokenizer(self) -> bool: + return self.tokenizer_type == TokenizerType.HF_TOKENIZER + + def is_tokenizer_none(self) -> bool: + if self.tokenizer_type != TokenizerType.NONE: + return False + + assert self.t is None, "tokenizer_type is NONE but t is not None" + return True + def validate_model( self, model: Optional[Model], @@ -290,12 +307,13 @@ def validate_model( if model is None: return - if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece]) != 1: + if self.is_tokenizer_none(): raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}") - is_tiktoken = self.is_tiktoken - is_sentencepiece = self.is_sentencepiece - is_hf_tokenizer = self.is_hf_tokenizer + is_tiktoken = self.is_tiktoken() + is_sentencepiece = self.is_sentencepiece() + is_hf_tokenizer = self.is_hf_tokenizer() + use_tiktoken = model.config.use_tiktoken use_hf_tokenizer = model.config.use_hf_tokenizer use_sentencepiece = not (use_tiktoken or use_hf_tokenizer) From b896262fca57a815067db2f17f153988442bf525 Mon Sep 17 00:00:00 2001 From: zhenyan-zhang-meta Date: Fri, 25 Apr 2025 16:48:58 -0700 Subject: [PATCH 4/7] Rollback to 98eaf8f --- torchchat/cli/builder.py | 83 ++++++++++++++++------------------------ 1 file changed, 33 insertions(+), 50 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 256d3d8d0..1bdc24378 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -16,14 +16,14 @@ import torch._dynamo.config import torch._inductor.config import torch.distributed as dist -from torchchat.distributed.logging_utils import SingletonLogger -from torchchat.distributed.utils import ( +from torchchat.distributed.utils import( Color as color, CUDATrackTime, - GPUMemoryMonitor, init_distributed, + GPUMemoryMonitor, ) +from torchchat.distributed.logging_utils import SingletonLogger from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs from torchchat.model_config.model_config import resolve_model_config @@ -36,6 +36,7 @@ from torchchat.utils.measure_time import measure_time from torchchat.utils.quantize import quantize_model + from torchtune.models.convert_weights import meta_to_tune from torchtune.models.llama3_1._position_embeddings import 
Llama3ScaledRoPE @@ -187,19 +188,15 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": tp = getattr(args, "tp", 1) chpt_from = getattr(args, "chpt_from", "hf") sdp_backend_dict = { - "math": torch.nn.attention.SDPBackend.MATH, - "flash_attention": torch.nn.attention.SDPBackend.FLASH_ATTENTION, - "efficient_attention": torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION, - "cudnn_attention": torch.nn.attention.SDPBackend.CUDNN_ATTENTION, + 'math': torch.nn.attention.SDPBackend.MATH, + 'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION, + 'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION, + 'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION, } attention_backend = sdp_backend_dict[args.attention_backend] - if args.device == "cpu" and ( - args.attention_backend == "efficient_attention" - or args.attention_backend == "cudnn_attention" - ): - print( - f"Warning: {args.attention_backend} is not supported on CPU. Using math instead." - ) + if args.device == "cpu" and (args.attention_backend == "efficient_attention" + or args.attention_backend == "cudnn_attention"): + print(f"Warning: {args.attention_backend} is not supported on CPU. Using math instead.") attention_backend = torch.nn.attention.SDPBackend.MATH return cls( checkpoint_dir=checkpoint_dir, @@ -281,6 +278,8 @@ def __post_init__(self): except: pass + self.tokenizer_type = TokenizerType.NONE + self.t = None return def is_tiktoken(self) -> bool: @@ -292,13 +291,6 @@ def is_sentencepiece(self) -> bool: def is_hf_tokenizer(self) -> bool: return self.tokenizer_type == TokenizerType.HF_TOKENIZER - def is_tokenizer_none(self) -> bool: - if self.tokenizer_type != TokenizerType.NONE: - return False - - assert self.t is None, "tokenizer_type is NONE but t is not None" - return True - def validate_model( self, model: Optional[Model], @@ -307,21 +299,22 @@ def validate_model( if model is None: return - if self.is_tokenizer_none(): - raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}") is_tiktoken = self.is_tiktoken() is_sentencepiece = self.is_sentencepiece() is_hf_tokenizer = self.is_hf_tokenizer() + if sum([is_tiktoken, is_hf_tokenizer, is_sentencepiece]) != 1: + raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}") + use_tiktoken = model.config.use_tiktoken use_hf_tokenizer = model.config.use_hf_tokenizer use_sentencepiece = not (use_tiktoken or use_hf_tokenizer) if ( - (is_tiktoken and not use_tiktoken) - or (is_hf_tokenizer and not use_hf_tokenizer) - or (is_sentencepiece and not use_sentencepiece) + (is_tiktoken and not use_tiktoken) or + (is_hf_tokenizer and not use_hf_tokenizer) or + (is_sentencepiece and not use_sentencepiece) ): raise RuntimeError( "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format( @@ -519,7 +512,6 @@ def _load_model(builder_args: BuilderArgs) -> Model: # AOTI-compoiled model will load its own weights. 
# Release weights here to avoid OOM import gc - if hasattr(model, "model"): model.model = None gc.collect() @@ -577,7 +569,6 @@ def _initialize_model( def do_nothing(max_batch_size, max_seq_length): pass - model.setup_caches = do_nothing model.forward = torch._export.aot_load( @@ -615,7 +606,6 @@ def do_nothing(max_batch_size, max_seq_length): def do_nothing(max_batch_size, max_seq_length): pass - model.setup_caches = do_nothing model.forward = aoti_compiled_model @@ -667,15 +657,12 @@ def do_nothing(max_batch_size, max_seq_length): try: model = torch.load(builder_args.snapshot_path, weights_only=False) except Exception: - raise RuntimeError( - f"Failed to load torchchat snapshot {builder_args.snapshot_path}" - ) + raise RuntimeError(f"Failed to load torchchat snapshot {builder_args.snapshot_path}") # _active_backend() does not allow DSO & AOTI to be true. # Choose either. from torchchat.utils.build_utils import set_backend - - set_backend(dso=True, pte=False, aoti_package=False) - if model.config != config: + set_backend (dso=True, pte=False, aoti_package=False) + if (model.config != config): raise RuntimeError("loaded model architecture mismatch") ## ## import all libraries with custom kernels ans custom operators @@ -693,9 +680,7 @@ def do_nothing(max_batch_size, max_seq_length): logger = SingletonLogger.get_logger() gpu_memory_monitor = GPUMemoryMonitor("cuda") - logger.info( - f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}" - ) + logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}") # Model-level config if builder_args.params_table: @@ -706,16 +691,20 @@ def do_nothing(max_batch_size, max_seq_length): config = TransformerArgs.from_params(model_config.transformer_args["text"]) logger.info(f"Transformer Config: {config}") - # TODO: Move into head of file after solving circular import - from torchchat.distributed.checkpoint_utils import load_model_weights + #TODO: Move into head of file after solving circular import + from torchchat.distributed.checkpoint_utils import ( + load_model_weights, + ) # Validate pipeline degree assert config.n_layers % pp_degree == 0 # Create device mesh device_mesh = dist.init_device_mesh( - "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp") - ) + "cuda", + (pp_degree, tp_degree), + mesh_dim_names=("pp", "tp") + ) tp_mesh = device_mesh["tp"] pp_mesh = device_mesh["pp"] logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}") @@ -744,13 +733,7 @@ def do_nothing(max_batch_size, max_seq_length): # Load weights logger.info(f"Loading weights for {pp_rank=} on {device=}") with CUDATrackTime() as timer: - load_model_weights( - model, - builder_args.distribution_path, - device, - config, - builder_args.chpt_from, - ) + load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from) logger.info( f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}" @@ -764,7 +747,7 @@ def do_nothing(max_batch_size, max_seq_length): # lanes. 
# TODO: bump up the lane count pipeline_lanes = 1 - seqlen_prefill = 1024 + seqlen_prefill=1024 with device: model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes) From c846de9dd784bbe2f510bb85a75bc59bdc47a70e Mon Sep 17 00:00:00 2001 From: zhenyan-zhang-meta Date: Fri, 25 Apr 2025 16:51:27 -0700 Subject: [PATCH 5/7] Add No Tokenizer Checker --- torchchat/cli/builder.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 1bdc24378..6b4d58521 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -278,8 +278,6 @@ def __post_init__(self): except: pass - self.tokenizer_type = TokenizerType.NONE - self.t = None return def is_tiktoken(self) -> bool: @@ -291,6 +289,13 @@ def is_sentencepiece(self) -> bool: def is_hf_tokenizer(self) -> bool: return self.tokenizer_type == TokenizerType.HF_TOKENIZER + def is_tokenizer_none(self) -> bool: + if self.tokenizer_type != TokenizerType.NONE: + return False + + assert self.t is None, "tokenizer_type is NONE but t is not None" + return True + def validate_model( self, model: Optional[Model], @@ -299,14 +304,13 @@ def validate_model( if model is None: return + if self.is_tokenizer_none(): + raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}") is_tiktoken = self.is_tiktoken() is_sentencepiece = self.is_sentencepiece() is_hf_tokenizer = self.is_hf_tokenizer() - if sum([is_tiktoken, is_hf_tokenizer, is_sentencepiece]) != 1: - raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}") - use_tiktoken = model.config.use_tiktoken use_hf_tokenizer = model.config.use_hf_tokenizer use_sentencepiece = not (use_tiktoken or use_hf_tokenizer) From c752a40f7c20976ff9848e71b52a1465b472e6f7 Mon Sep 17 00:00:00 2001 From: zhenyan-zhang-meta Date: Mon, 28 Apr 2025 10:34:01 -0700 Subject: [PATCH 6/7] Reply to nits --- torchchat/cli/builder.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 6b4d58521..f5f3647ee 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -278,8 +278,6 @@ def __post_init__(self): except: pass - return - def is_tiktoken(self) -> bool: return self.tokenizer_type == TokenizerType.TIKTOKEN @@ -290,11 +288,7 @@ def is_hf_tokenizer(self) -> bool: return self.tokenizer_type == TokenizerType.HF_TOKENIZER def is_tokenizer_none(self) -> bool: - if self.tokenizer_type != TokenizerType.NONE: - return False - - assert self.t is None, "tokenizer_type is NONE but t is not None" - return True + return self.tokenizer_type == TokenizerType.NONE def validate_model( self, From 03e2019528835ee11d54adab864156c1d5f39609 Mon Sep 17 00:00:00 2001 From: zhenyan-zhang-meta Date: Mon, 28 Apr 2025 10:35:34 -0700 Subject: [PATCH 7/7] Reply to nits --- torchchat/cli/builder.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index f5f3647ee..0661d08f5 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -287,9 +287,6 @@ def is_sentencepiece(self) -> bool: def is_hf_tokenizer(self) -> bool: return self.tokenizer_type == TokenizerType.HF_TOKENIZER - def is_tokenizer_none(self) -> bool: - return self.tokenizer_type == TokenizerType.NONE - def validate_model( self, model: Optional[Model], @@ -298,7 +295,7 @@ def validate_model( if model is None: return - if self.is_tokenizer_none(): + if self.tokenizer_type == TokenizerType.NONE: raise RuntimeError(f"no 
tokenizer was found at {self.tokenizer_path}") is_tiktoken = self.is_tiktoken()