From 798abbff3a63900d267970d23fe9ea471bc0fc2d Mon Sep 17 00:00:00 2001 From: tigranfah Date: Sat, 28 Dec 2024 23:24:21 +0400 Subject: [PATCH 1/3] add WSD-S scheduler --- torchtitan/config_manager.py | 76 +++++++++++++++++------------ torchtitan/optimizer.py | 87 +++++++++++++++++++++++++++++----- train_configs/debug_model.toml | 6 ++- train_configs/llama3.2_1b.toml | 2 +- 4 files changed, 127 insertions(+), 44 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index bdadb318..ae922f55 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -16,16 +16,17 @@ except ModuleNotFoundError: import tomli as tomllib -from torchtitan.logging import logger, validate_log_level - from typing import Optional +from torchtitan.logging import logger, validate_log_level + TORCH_DTYPE_MAP = { "float16": torch.float16, "float32": torch.float32, "bfloat16": torch.bfloat16, } + def string_list(raw_arg): return raw_arg.split(",") @@ -181,7 +182,10 @@ def __init__(self): "--optimizer.name", type=str, default="AdamW", help="Optimizer to use" ) self.parser.add_argument( - "--optimizer.schedule", type=str, default="Linear", help="Optimization schedule to use" + "--optimizer.schedule", + type=str, + default="Linear", + help="Optimization schedule to use", ) self.parser.add_argument( "--optimizer.lr", type=float, default=8e-4, help="Learning rate to use" @@ -225,17 +229,29 @@ def __init__(self): help="Steps for lr scheduler warmup, normally 1/5 of --training.steps", ) self.parser.add_argument( - "--training.decay_steps", - type=Optional[int], - default=None, - help="Steps for lr scheduler decay, default is decay starts immediately after warmup", + "--training.num_decays", + type=Optional[float], + default=1, + help="The number of total decays to perform throughout the training, following the WSD-S scheduler", ) self.parser.add_argument( - "--training.decay_type", - type=str, - default="linear", - choices = ["linear","cosine"], - help="Steps for lr scheduler decay type, defaults to linear", + "--training.decay_steps", + type=Optional[int], + default=None, + help="Steps for lr scheduler decay, default is decay starts immediately after warmup", + ) + self.parser.add_argument( + "--training.decay_steps_perc", + type=Optional[float], + default=1.0, + help="The percentage of the steps to use as decay steps", + ) + self.parser.add_argument( + "--training.decay_type", + type=str, + default="linear", + choices=["linear", "cosine"], + help="Steps for lr scheduler decay type, defaults to linear", ) self.parser.add_argument( "--training.max_norm", @@ -266,7 +282,7 @@ def __init__(self): default=True, action="store_true", help="Whether to apply loss parallel when sequence parallel is enabled", - ) + ) self.parser.add_argument( "--training.representation_type", default="SMILES", @@ -387,9 +403,7 @@ def __init__(self): ) # validation configs - self.parser.add_argument( - "--validation.batch_size", type=int, default=None - ) + self.parser.add_argument("--validation.batch_size", type=int, default=None) self.parser.add_argument( "--validation.dataset", type=str, help="Dataset to use", default=None ) @@ -402,10 +416,16 @@ def __init__(self): default=None, ) self.parser.add_argument( - "--validation.valid_freq", type=int, default=1024, help="How often to evaluate the model and log metrics to aim." + "--validation.valid_freq", + type=int, + default=1024, + help="How often to evaluate the model and log metrics to aim.", ) self.parser.add_argument( - "--validation.enable_valid", type=bool, default=False, help="Whether to do validation." + "--validation.enable_valid", + type=bool, + default=False, + help="Whether to do validation.", ) # checkpointing configs @@ -647,35 +667,33 @@ def __init__(self): ) self.parser.add_argument( "--logging.log_level", - default = "INFO", + default="INFO", choices=["INFO", "DEBUG", "WARNING", "ERROR", "CRITICAL"], type=str, - help="Set the log level, INFO by default" + help="Set the log level, INFO by default", ) self.parser.add_argument( "--dataloader.num_workers", - default = 0, + default=0, type=int, help="""Set the number of dataloader workers PER RANK, default is 0. 1 is non-blocking. - More than 1 may lead to issues with data splitting / duplication""" + More than 1 may lead to issues with data splitting / duplication""", ) self.parser.add_argument( "--dataloader.pin_memory", - default = False, + default=False, type=bool, - help= "Whether or not to pin dataloader memory" + help="Whether or not to pin dataloader memory", ) self.parser.add_argument( "--dataloader.special_mode", - default = None, - choices = ["yield_tensor"], + default=None, + choices=["yield_tensor"], type=str, - help= "Enable a special dataloading mode, useful for debugging" + help="Enable a special dataloading mode, useful for debugging", ) - - def parse_args(self, args_list: list = sys.argv[1:]): self.args_list = args_list diff --git a/torchtitan/optimizer.py b/torchtitan/optimizer.py index 5f180aa5..19464c37 100644 --- a/torchtitan/optimizer.py +++ b/torchtitan/optimizer.py @@ -6,11 +6,11 @@ import functools import math +from enum import Enum import torch from torch.optim.lr_scheduler import LambdaLR from torchtitan.config_manager import JobConfig -from enum import Enum def build_optimizers(model_parts, job_config: JobConfig): @@ -57,12 +57,14 @@ def zero_grad(self): return OptimizersContainer([_build_optimizer(model) for model in model_parts]) + def linear_warmup(warmup_steps: int, current_step: int) -> float: """Computes the linear warmup scaling factor.""" if warmup_steps <= 0: raise ValueError("warmup_steps must be positive.") return float((current_step + 1) / (warmup_steps + 1)) + # Decay functions def linear_decay(decay_steps: int, current_step: int, start_step: int) -> float: """Computes the linear decay scaling factor.""" @@ -71,6 +73,7 @@ def linear_decay(decay_steps: int, current_step: int, start_step: int) -> float: progress = float((current_step - start_step) / decay_steps) return max(0.0, 1 - progress) + def cosine_decay(decay_steps: int, current_step: int, start_step: int) -> float: """Computes the cosine decay scaling factor.""" if decay_steps <= 0: @@ -78,35 +81,42 @@ def cosine_decay(decay_steps: int, current_step: int, start_step: int) -> float: current_step = min(current_step - start_step, decay_steps) return 0.5 * (1 + math.cos(math.pi * current_step / decay_steps)) + class Decay(Enum): LINEAR = functools.partial(linear_decay) COSINE = functools.partial(cosine_decay) @staticmethod - def from_string(decay_type: str) -> 'Decay': + def from_string(decay_type: str) -> "Decay": """Converts a string to the corresponding Decay enum value.""" try: return Decay[decay_type.upper()] - except KeyError: - raise ValueError(f"Invalid decay type: {decay_type}. Expected one of {list(Decay.__members__.keys())}") + except KeyError as e: + raise ValueError( + f"Invalid decay type: {decay_type}. Expected one of {list(Decay.__members__.keys())}" + ) from e def warmup_stable_decay( - decay_type: Decay, warmup_steps: int, decay_steps: int,training_steps:int, current_step: int + decay_type: Decay, + warmup_steps: int, + decay_steps: int, + training_steps: int, + current_step: int, ) -> float: """Computes linear warmup followed by linear decay. Per LambdaLR requirement, this is accomplished by returning a multiplicative factor to adjust the learning rate to create the desired schedule. """ - start_decay_step = training_steps-decay_steps + start_decay_step = training_steps - decay_steps if current_step < warmup_steps: # warmup phase - curr_adjustment = linear_warmup(warmup_steps,current_step) - return linear_warmup(warmup_steps,current_step) + curr_adjustment = linear_warmup(warmup_steps, current_step) + return linear_warmup(warmup_steps, current_step) - elif (current_step >= warmup_steps) and (current_step= warmup_steps) and (current_step < start_decay_step): # stable phase, no adjustment to lr return 1.0 @@ -114,6 +124,44 @@ def warmup_stable_decay( # decay phase supporting multiple decay functions return decay_type.value(decay_steps, current_step, start_decay_step) + +# implementation of WSD-S scheduler +def warmup_stable_decay_simplified( + decay_type: Decay, + warmup_steps: int, + decay_steps_perc: float, + num_decays: int, + training_steps: int, + current_step: int, +) -> float: + # num steps for each decay + per_decay_num_steps = training_steps // num_decays + # current decay index + decay_index = math.ceil(current_step / per_decay_num_steps) + # the step at which lr is decayed + decay_at_step = decay_index * per_decay_num_steps + # number of decay steps + if decay_index == 1: + # make sure the decay_steps_perc does not include the warmup_steps + decay_steps_perc = min(decay_steps_perc, 1 - warmup_steps / decay_at_step) + + decay_steps = int(decay_at_step * decay_steps_perc) + # the step at which to start the decay + start_decay_step = decay_at_step - decay_steps + + if current_step < warmup_steps: + # warmup phase + curr_adjustment = current_step / warmup_steps + elif current_step < start_decay_step: + # stable phase, no adjustment to lr + curr_adjustment = 1.0 + else: + # decay phase supporting multiple decay functions + curr_adjustment = decay_type.value(decay_steps, current_step, start_decay_step) + + return curr_adjustment + + def build_lr_schedulers(optimizers, job_config: JobConfig) -> LambdaLR: def _build_lr_scheduler(optimizer): """Build a linear warmup optionally stable and linear decay scheduler""" @@ -121,11 +169,25 @@ def _build_lr_scheduler(optimizer): post_warmup_steps = float(max(1, job_config.training.steps - warmup_steps)) # If decay steps is not set in config, decay will begin immediately after warmup - decay_steps = job_config.training.decay_steps if job_config.training.decay_steps else post_warmup_steps + decay_steps = ( + job_config.training.decay_steps + if job_config.training.decay_steps + else post_warmup_steps + ) + decay_steps_perc = job_config.training.decay_steps_perc + num_decays = job_config.training.num_decays decay_type = Decay.from_string(job_config.training.decay_type) + # lr_lambda = functools.partial( + # warmup_stable_decay, decay_type, warmup_steps, decay_steps, job_config.training.steps + # ) lr_lambda = functools.partial( - warmup_stable_decay, decay_type ,warmup_steps, decay_steps, job_config.training.steps + warmup_stable_decay_simplified, + decay_type, + warmup_steps, + decay_steps_perc, + num_decays, + job_config.training.steps, ) warmup_stable_decay_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda) return warmup_stable_decay_scheduler @@ -139,7 +201,8 @@ def __init__(self, schedulers): def step(self): for schedulers in self.schedulers: schedulers.step() - @property + + @property def last_lr(self): return self.schedulers[0].get_last_lr()[0] diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index 56ba99ae..6bfedb7f 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -30,10 +30,12 @@ lr = 8e-4 [training] batch_size = 1 -gradient_accumulation_steps = 48 +gradient_accumulation_steps = 1 seq_len = 2048 -warmup_steps = 20 # lr scheduler warm up, normally 20% of the train steps max_norm = 1.0 # grad norm clipping +warmup_steps = 5 # lr scheduler warm up, normally 20% of the train steps +decay_steps_perc = 0.1 +num_decays = 4 steps = 100 data_parallel_degree = -1 tensor_parallel_degree = 1 diff --git a/train_configs/llama3.2_1b.toml b/train_configs/llama3.2_1b.toml index b9691e19..a5c14187 100644 --- a/train_configs/llama3.2_1b.toml +++ b/train_configs/llama3.2_1b.toml @@ -32,8 +32,8 @@ lr = 6e-4 batch_size = 10 gradient_accumulation_steps = 16 seq_len = 2048 -warmup_steps = 500 # lr scheduler warm up, normally 20% of the train steps max_norm = 1.0 # grad norm clipping +warmup_steps = 500 # lr scheduler warm up, normally 20% of the train steps steps = 40000 data_parallel_degree = -1 tensor_parallel_degree = 1 From aa42bb177c9e1b4bac017282d84e68509c33d06e Mon Sep 17 00:00:00 2001 From: tigranfah Date: Sun, 29 Dec 2024 00:51:44 +0400 Subject: [PATCH 2/3] update llama-3.2-1b config to use WSD-S scheduler --- train_configs/debug_model.toml | 6 +++--- train_configs/llama3.2_1b.toml | 6 +++--- train_configs/llama3.2_3b.toml | 2 -- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index 6bfedb7f..86672447 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -34,9 +34,9 @@ gradient_accumulation_steps = 1 seq_len = 2048 max_norm = 1.0 # grad norm clipping warmup_steps = 5 # lr scheduler warm up, normally 20% of the train steps +steps = 200 decay_steps_perc = 0.1 num_decays = 4 -steps = 100 data_parallel_degree = -1 tensor_parallel_degree = 1 compile = true @@ -58,12 +58,12 @@ enable_async_tensor_parallel = false [checkpoint] enable_checkpoint = false -# load_folder = "yerevann/Llama-debug/b00ef18db9d447ff84b9035a" +# load_folder = "yerevann/Llama-debug/bab005ed36ef4e02a3e62333" save_folder = "yerevann/Llama-debug" # load_at_step = 100 create_seed_checkpoint = false interval_type = "steps" -interval = 50 +interval = 100 model_weights_only = false export_dtype = "float32" async_mode = "async_with_pinned_mem" # ["disabled", "async", "async_with_pinned_mem"] diff --git a/train_configs/llama3.2_1b.toml b/train_configs/llama3.2_1b.toml index a5c14187..37635de7 100644 --- a/train_configs/llama3.2_1b.toml +++ b/train_configs/llama3.2_1b.toml @@ -34,12 +34,12 @@ gradient_accumulation_steps = 16 seq_len = 2048 max_norm = 1.0 # grad norm clipping warmup_steps = 500 # lr scheduler warm up, normally 20% of the train steps -steps = 40000 +steps = 80000 +decay_steps_perc = 0.1 +num_decays = 4 data_parallel_degree = -1 tensor_parallel_degree = 1 compile = true -# dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) -# dataset = "chemlactica_train_mini" # supported datasets: c4_test (2K), c4 (177M), chemlactica_train_mini (4K) dataset = "chemlactica_train" data_processing_style="chemlactica_style" representation_type = "SMILES" diff --git a/train_configs/llama3.2_3b.toml b/train_configs/llama3.2_3b.toml index eb394458..8b3993c2 100644 --- a/train_configs/llama3.2_3b.toml +++ b/train_configs/llama3.2_3b.toml @@ -38,8 +38,6 @@ steps = 40000 data_parallel_degree = -1 tensor_parallel_degree = 1 compile = true -# dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) -# dataset = "chemlactica_train_mini" # supported datasets: c4_test (2K), c4 (177M), chemlactica_train_mini (4K) dataset = "chemlactica_train" data_processing_style="chemlactica_style" representation_type = "SMILES" From 793f7636911ee1d2c4a4c80542ad98ff0c1330db Mon Sep 17 00:00:00 2001 From: tigranfah Date: Sun, 12 Jan 2025 13:12:46 +0400 Subject: [PATCH 3/3] set the config of llama-1b model to use WSD --- submitit_train.py | 3 ++- train_configs/llama3.2_1b.toml | 10 +++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/submitit_train.py b/submitit_train.py index 8bb54bfe..1a762bba 100644 --- a/submitit_train.py +++ b/submitit_train.py @@ -26,7 +26,8 @@ for _ in range(1): # train_config = './train_configs/chemlactica_125m.toml' # train_config = './train_configs/chemlactica_1.3b.toml' - train_config = "./train_configs/llama3.2_3b.toml" + train_config = "./train_configs/llama3.2_1b.toml" + # train_config = "./train_configs/llama3.2_3b.toml" # train_config = './train_configs/debug_model.toml' function = submitit.helpers.CommandFunction( [ diff --git a/train_configs/llama3.2_1b.toml b/train_configs/llama3.2_1b.toml index 37635de7..60824996 100644 --- a/train_configs/llama3.2_1b.toml +++ b/train_configs/llama3.2_1b.toml @@ -26,7 +26,7 @@ tokenizer_path = "torchtitan/tokenizers/Llama-3.2-chem-1B-v1/" [optimizer] name = "AdamW" -lr = 6e-4 +lr = 4e-4 [training] batch_size = 10 @@ -36,12 +36,12 @@ max_norm = 1.0 # grad norm clipping warmup_steps = 500 # lr scheduler warm up, normally 20% of the train steps steps = 80000 decay_steps_perc = 0.1 -num_decays = 4 +num_decays = 1 data_parallel_degree = -1 tensor_parallel_degree = 1 compile = true dataset = "chemlactica_train" -data_processing_style="chemlactica_style" +data_processing_style = "chemlactica_style" representation_type = "SMILES" [validation] @@ -60,8 +60,8 @@ enable_async_tensor_parallel = false enable_checkpoint = true save_folder = "yerevann/Llama-3.2-1B" load_folder = "meta-llama/Llama-3.2-1B" -# load_folder = "yerevann/Llama-3.2-1B/ec943c9e63db4cf7b4a8b847" -# load_at_step = 40000 +# load_folder = "yerevann/Llama-3.2-1B/c08769a3ed064389838fd8a5" +# load_at_step = 32000 interval_type = "steps" interval = 2000 model_weights_only = false