Merge pull request #35 from YerevaNN/wsd_s
add WSD-S scheduler
MenuaB authored Jan 20, 2025
2 parents 60dbe30 + 793f763 commit 642cbcc
Showing 6 changed files with 139 additions and 57 deletions.
3 changes: 2 additions & 1 deletion submitit_train.py
@@ -26,7 +26,8 @@
for _ in range(1):
# train_config = './train_configs/chemlactica_125m.toml'
# train_config = './train_configs/chemlactica_1.3b.toml'
train_config = "./train_configs/llama3.2_3b.toml"
train_config = "./train_configs/llama3.2_1b.toml"
# train_config = "./train_configs/llama3.2_3b.toml"
# train_config = './train_configs/debug_model.toml'
function = submitit.helpers.CommandFunction(
[
76 changes: 47 additions & 29 deletions torchtitan/config_manager.py
@@ -16,16 +16,17 @@
except ModuleNotFoundError:
import tomli as tomllib

from torchtitan.logging import logger, validate_log_level

from typing import Optional

from torchtitan.logging import logger, validate_log_level

TORCH_DTYPE_MAP = {
"float16": torch.float16,
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}


def string_list(raw_arg):
return raw_arg.split(",")

@@ -181,7 +182,10 @@ def __init__(self):
"--optimizer.name", type=str, default="AdamW", help="Optimizer to use"
)
self.parser.add_argument(
"--optimizer.schedule", type=str, default="Linear", help="Optimization schedule to use"
"--optimizer.schedule",
type=str,
default="Linear",
help="Optimization schedule to use",
)
self.parser.add_argument(
"--optimizer.lr", type=float, default=8e-4, help="Learning rate to use"
@@ -225,17 +229,29 @@ def __init__(self):
help="Steps for lr scheduler warmup, normally 1/5 of --training.steps",
)
self.parser.add_argument(
"--training.decay_steps",
type=Optional[int],
default=None,
help="Steps for lr scheduler decay, default is decay starts immediately after warmup",
"--training.num_decays",
type=Optional[float],
default=1,
help="The number of total decays to perform throughout the training, following the WSD-S scheduler",
)
self.parser.add_argument(
"--training.decay_type",
type=str,
default="linear",
choices = ["linear","cosine"],
help="Steps for lr scheduler decay type, defaults to linear",
"--training.decay_steps",
type=Optional[int],
default=None,
help="Steps for lr scheduler decay, default is decay starts immediately after warmup",
)
self.parser.add_argument(
"--training.decay_steps_perc",
type=Optional[float],
default=1.0,
help="The percentage of the steps to use as decay steps",
)
self.parser.add_argument(
"--training.decay_type",
type=str,
default="linear",
choices=["linear", "cosine"],
help="Steps for lr scheduler decay type, defaults to linear",
)
self.parser.add_argument(
"--training.max_norm",
@@ -266,7 +282,7 @@ def __init__(self):
default=True,
action="store_true",
help="Whether to apply loss parallel when sequence parallel is enabled",
)
)
self.parser.add_argument(
"--training.representation_type",
default="SMILES",
@@ -387,9 +403,7 @@ def __init__(self):
)

# validation configs
self.parser.add_argument(
"--validation.batch_size", type=int, default=None
)
self.parser.add_argument("--validation.batch_size", type=int, default=None)
self.parser.add_argument(
"--validation.dataset", type=str, help="Dataset to use", default=None
)
@@ -402,10 +416,16 @@ def __init__(self):
default=None,
)
self.parser.add_argument(
"--validation.valid_freq", type=int, default=1024, help="How often to evaluate the model and log metrics to aim."
"--validation.valid_freq",
type=int,
default=1024,
help="How often to evaluate the model and log metrics to aim.",
)
self.parser.add_argument(
"--validation.enable_valid", type=bool, default=False, help="Whether to do validation."
"--validation.enable_valid",
type=bool,
default=False,
help="Whether to do validation.",
)

# checkpointing configs
@@ -647,35 +667,33 @@ def __init__(self):
)
self.parser.add_argument(
"--logging.log_level",
default = "INFO",
default="INFO",
choices=["INFO", "DEBUG", "WARNING", "ERROR", "CRITICAL"],
type=str,
help="Set the log level, INFO by default"
help="Set the log level, INFO by default",
)
self.parser.add_argument(
"--dataloader.num_workers",
default = 0,
default=0,
type=int,
help="""Set the number of dataloader workers PER RANK, default is 0. 1 is non-blocking.
More than 1 may lead to issues with data splitting / duplication"""
More than 1 may lead to issues with data splitting / duplication""",
)
self.parser.add_argument(
"--dataloader.pin_memory",
default = False,
default=False,
type=bool,
help= "Whether or not to pin dataloader memory"
help="Whether or not to pin dataloader memory",
)

self.parser.add_argument(
"--dataloader.special_mode",
default = None,
choices = ["yield_tensor"],
default=None,
choices=["yield_tensor"],
type=str,
help= "Enable a special dataloading mode, useful for debugging"
help="Enable a special dataloading mode, useful for debugging",
)



def parse_args(self, args_list: list = sys.argv[1:]):
self.args_list = args_list

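A minimal usage sketch for the scheduler flags added above (hypothetical values; assumes JobConfig.parse_args populates each "section.key" flag as config.<section>.<key>, as in upstream torchtitan). Note that --training.num_decays and --training.decay_steps_perc declare typing.Optional[...] as their argparse type, which argparse cannot apply to a command-line string, so the sketch leaves them at their defaults; the debug config further below sets them through the TOML file instead.

from torchtitan.config_manager import JobConfig

# Hypothetical invocation: only plain-typed flags are overridden on the command line.
config = JobConfig()
config.parse_args(
    [
        "--training.steps", "1000",
        "--training.warmup_steps", "100",
        "--training.decay_type", "cosine",  # choices: "linear" (default) or "cosine"
    ]
)

print(config.training.decay_type)        # "cosine"
print(config.training.num_decays)        # 1   (default: a single decay at the end of training)
print(config.training.decay_steps_perc)  # 1.0 (default: the decay spans the whole post-warmup run)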
87 changes: 75 additions & 12 deletions torchtitan/optimizer.py
@@ -6,11 +6,11 @@

import functools
import math
from enum import Enum

import torch
from torch.optim.lr_scheduler import LambdaLR
from torchtitan.config_manager import JobConfig
from enum import Enum


def build_optimizers(model_parts, job_config: JobConfig):
@@ -57,12 +57,14 @@ def zero_grad(self):

return OptimizersContainer([_build_optimizer(model) for model in model_parts])


def linear_warmup(warmup_steps: int, current_step: int) -> float:
"""Computes the linear warmup scaling factor."""
if warmup_steps <= 0:
raise ValueError("warmup_steps must be positive.")
return float((current_step + 1) / (warmup_steps + 1))


# Decay functions
def linear_decay(decay_steps: int, current_step: int, start_step: int) -> float:
"""Computes the linear decay scaling factor."""
@@ -71,61 +73,121 @@ def linear_decay(decay_steps: int, current_step: int, start_step: int) -> float:
progress = float((current_step - start_step) / decay_steps)
return max(0.0, 1 - progress)


def cosine_decay(decay_steps: int, current_step: int, start_step: int) -> float:
"""Computes the cosine decay scaling factor."""
if decay_steps <= 0:
raise ValueError("decay_steps must be positive.")
current_step = min(current_step - start_step, decay_steps)
return 0.5 * (1 + math.cos(math.pi * current_step / decay_steps))


class Decay(Enum):
LINEAR = functools.partial(linear_decay)
COSINE = functools.partial(cosine_decay)

@staticmethod
def from_string(decay_type: str) -> 'Decay':
def from_string(decay_type: str) -> "Decay":
"""Converts a string to the corresponding Decay enum value."""
try:
return Decay[decay_type.upper()]
except KeyError:
raise ValueError(f"Invalid decay type: {decay_type}. Expected one of {list(Decay.__members__.keys())}")
except KeyError as e:
raise ValueError(
f"Invalid decay type: {decay_type}. Expected one of {list(Decay.__members__.keys())}"
) from e


def warmup_stable_decay(
decay_type: Decay, warmup_steps: int, decay_steps: int,training_steps:int, current_step: int
decay_type: Decay,
warmup_steps: int,
decay_steps: int,
training_steps: int,
current_step: int,
) -> float:
"""Computes linear warmup followed by linear decay.
Per LambdaLR requirement, this is accomplished by returning
a multiplicative factor to adjust the learning rate to
create the desired schedule.
"""
start_decay_step = training_steps-decay_steps
start_decay_step = training_steps - decay_steps

if current_step < warmup_steps:
# warmup phase
curr_adjustment = linear_warmup(warmup_steps,current_step)
return linear_warmup(warmup_steps,current_step)
curr_adjustment = linear_warmup(warmup_steps, current_step)
return linear_warmup(warmup_steps, current_step)

elif (current_step >= warmup_steps) and (current_step<start_decay_step):
elif (current_step >= warmup_steps) and (current_step < start_decay_step):
# stable phase, no adjustment to lr
return 1.0

else:
# decay phase supporting multiple decay functions
return decay_type.value(decay_steps, current_step, start_decay_step)


# implementation of WSD-S scheduler
def warmup_stable_decay_simplified(
decay_type: Decay,
warmup_steps: int,
decay_steps_perc: float,
num_decays: int,
training_steps: int,
current_step: int,
) -> float:
# num steps for each decay
per_decay_num_steps = training_steps // num_decays
# current decay index
decay_index = math.ceil(current_step / per_decay_num_steps)
# the step at which lr is decayed
decay_at_step = decay_index * per_decay_num_steps
# number of decay steps
if decay_index == 1:
# make sure the decay_steps_perc does not include the warmup_steps
decay_steps_perc = min(decay_steps_perc, 1 - warmup_steps / decay_at_step)

decay_steps = int(decay_at_step * decay_steps_perc)
# the step at which to start the decay
start_decay_step = decay_at_step - decay_steps

if current_step < warmup_steps:
# warmup phase
curr_adjustment = current_step / warmup_steps
elif current_step < start_decay_step:
# stable phase, no adjustment to lr
curr_adjustment = 1.0
else:
# decay phase supporting multiple decay functions
curr_adjustment = decay_type.value(decay_steps, current_step, start_decay_step)

return curr_adjustment


def build_lr_schedulers(optimizers, job_config: JobConfig) -> LambdaLR:
def _build_lr_scheduler(optimizer):
"""Build a linear warmup optionally stable and linear decay scheduler"""
warmup_steps = int(job_config.training.warmup_steps)
post_warmup_steps = float(max(1, job_config.training.steps - warmup_steps))

# If decay steps is not set in config, decay will begin immediately after warmup
decay_steps = job_config.training.decay_steps if job_config.training.decay_steps else post_warmup_steps
decay_steps = (
job_config.training.decay_steps
if job_config.training.decay_steps
else post_warmup_steps
)
decay_steps_perc = job_config.training.decay_steps_perc
num_decays = job_config.training.num_decays
decay_type = Decay.from_string(job_config.training.decay_type)

# lr_lambda = functools.partial(
# warmup_stable_decay, decay_type, warmup_steps, decay_steps, job_config.training.steps
# )
lr_lambda = functools.partial(
warmup_stable_decay, decay_type ,warmup_steps, decay_steps, job_config.training.steps
warmup_stable_decay_simplified,
decay_type,
warmup_steps,
decay_steps_perc,
num_decays,
job_config.training.steps,
)
warmup_stable_decay_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
return warmup_stable_decay_scheduler
Expand All @@ -139,7 +201,8 @@ def __init__(self, schedulers):
def step(self):
for schedulers in self.schedulers:
schedulers.step()
@property

@property
def last_lr(self):
return self.schedulers[0].get_last_lr()[0]

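For reference, a self-contained sketch of how the new WSD-S lambda is attached to an optimizer, mirroring build_lr_schedulers above. The toy optimizer and the hyperparameter values are illustrative only, and the sketch assumes this fork's torchtitan.optimizer module is importable.

import functools

import torch
from torch.optim.lr_scheduler import LambdaLR

from torchtitan.optimizer import Decay, warmup_stable_decay_simplified

# Toy single-parameter optimizer standing in for the real model's AdamW.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=8e-4)

training_steps, warmup_steps = 1000, 100
lr_lambda = functools.partial(
    warmup_stable_decay_simplified,
    Decay.from_string("linear"),
    warmup_steps,
    0.1,  # decay_steps_perc: each decay window covers 10% of its end step
    2,    # num_decays: decay windows ending at steps 500 and 1000
    training_steps,
)
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

for step in range(training_steps):
    optimizer.step()  # no-op here: the toy parameter never receives gradients
    scheduler.step()
    if step in (99, 450, 499, 500, 900, 999):
        print(step, scheduler.get_last_lr()[0])

# Expected shape: the lr ramps up over the first 100 steps, holds at 8e-4,
# anneals to 0 over steps 450-500, jumps straight back to 8e-4 (no re-warmup),
# holds again, and anneals to 0 over steps 900-1000.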
12 changes: 7 additions & 5 deletions train_configs/debug_model.toml
@@ -30,11 +30,13 @@ lr = 8e-4

[training]
batch_size = 1
gradient_accumulation_steps = 48
gradient_accumulation_steps = 1
seq_len = 2048
warmup_steps = 20 # lr scheduler warm up, normally 20% of the train steps
max_norm = 1.0 # grad norm clipping
steps = 100
warmup_steps = 5 # lr scheduler warm up, normally 20% of the train steps
steps = 200
decay_steps_perc = 0.1
num_decays = 4
data_parallel_degree = -1
tensor_parallel_degree = 1
compile = true
Expand All @@ -56,12 +58,12 @@ enable_async_tensor_parallel = false

[checkpoint]
enable_checkpoint = false
# load_folder = "yerevann/Llama-debug/b00ef18db9d447ff84b9035a"
# load_folder = "yerevann/Llama-debug/bab005ed36ef4e02a3e62333"
save_folder = "yerevann/Llama-debug"
# load_at_step = 100
create_seed_checkpoint = false
interval_type = "steps"
interval = 50
interval = 100
model_weights_only = false
export_dtype = "float32"
async_mode = "async_with_pinned_mem" # ["disabled", "async", "async_with_pinned_mem"]
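A quick arithmetic check of what these debug settings imply, reproducing the window computation from warmup_stable_decay_simplified for steps = 200, warmup_steps = 5, num_decays = 4 and decay_steps_perc = 0.1 (pure Python, no torch needed):

training_steps, warmup_steps = 200, 5
num_decays, decay_steps_perc = 4, 0.1

per_decay_num_steps = training_steps // num_decays  # 50 steps per decay cycle
for decay_index in range(1, num_decays + 1):
    decay_at_step = decay_index * per_decay_num_steps
    perc = decay_steps_perc
    if decay_index == 1:
        # keep the first decay window clear of the warmup phase
        perc = min(perc, 1 - warmup_steps / decay_at_step)
    decay_steps = int(decay_at_step * perc)
    start_decay_step = decay_at_step - decay_steps
    print(f"decay {decay_index}: lr anneals over steps {start_decay_step}-{decay_at_step}")

# decay 1: lr anneals over steps 45-50
# decay 2: lr anneals over steps 90-100
# decay 3: lr anneals over steps 135-150
# decay 4: lr anneals over steps 180-200

Because decay_steps is taken as a fraction of the cumulative decay_at_step rather than of the 50-step cycle, later decay windows are proportionally longer (5, 10, 15 and 20 steps here).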
(Diffs for the remaining two changed files were not loaded in this view.)