diff --git a/olmo/config.py b/olmo/config.py
index b4b0576f9..f3768d18d 100644
--- a/olmo/config.py
+++ b/olmo/config.py
@@ -480,14 +480,20 @@ class SchedulerType(StrEnum):
     constant = "constant"
 
 
+class SchedulerUnits(StrEnum):
+    steps = "steps"
+    tokens = "tokens"
+
+
 @dataclass
 class SchedulerConfig(BaseConfig):
     name: SchedulerType = SchedulerType.cosine_with_warmup
-    t_warmup: int = 100
-    t_max: Optional[int] = None
+    units: SchedulerUnits = SchedulerUnits.steps
+    t_warmup: Union[int, float] = 100
+    t_max: Optional[Union[int, float]] = None
     alpha_f: float = 0.1
 
-    grad_clip_warmup_steps: Optional[int] = None
+    grad_clip_warmup_steps: Optional[Union[int, float]] = None
     """
     The warmup period for which the max grad norm (or norm ratio) will be set to its warmup value of
     `max_grad_norm * grad_clip_warmup_factor`.
diff --git a/olmo/optim.py b/olmo/optim.py
index 711e6d889..91d535e72 100644
--- a/olmo/optim.py
+++ b/olmo/optim.py
@@ -720,36 +720,46 @@ def build_scheduler(cfg: TrainConfig, sched_cfg: Optional[SchedulerConfig] = Non
     sched_cfg = sched_cfg if sched_cfg is not None else cfg.scheduler
     if sched_cfg.name == SchedulerType.cosine_with_warmup:
         return CosWithWarmup(
-            grad_clip_warmup_steps=sched_cfg.grad_clip_warmup_steps,
+            grad_clip_warmup_steps=None
+            if sched_cfg.grad_clip_warmup_steps is None
+            else int(sched_cfg.grad_clip_warmup_steps),
             grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
-            warmup_steps=sched_cfg.t_warmup,
+            warmup_steps=int(sched_cfg.t_warmup),
             alpha_f=sched_cfg.alpha_f,
-            t_max=sched_cfg.t_max,
+            t_max=None if sched_cfg.t_max is None else int(sched_cfg.t_max),
         )
     elif sched_cfg.name == SchedulerType.linear_with_warmup:
         return LinearWithWarmup(
-            grad_clip_warmup_steps=sched_cfg.grad_clip_warmup_steps,
+            grad_clip_warmup_steps=None
+            if sched_cfg.grad_clip_warmup_steps is None
+            else int(sched_cfg.grad_clip_warmup_steps),
             grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
-            warmup_steps=sched_cfg.t_warmup,
+            warmup_steps=int(sched_cfg.t_warmup),
             alpha_f=sched_cfg.alpha_f,
-            t_max=sched_cfg.t_max,
+            t_max=None if sched_cfg.t_max is None else int(sched_cfg.t_max),
         )
     elif sched_cfg.name == SchedulerType.inverse_sqrt_with_warmup:
         return InvSqrtWithWarmup(
-            grad_clip_warmup_steps=sched_cfg.grad_clip_warmup_steps,
+            grad_clip_warmup_steps=None
+            if sched_cfg.grad_clip_warmup_steps is None
+            else int(sched_cfg.grad_clip_warmup_steps),
             grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
-            warmup_steps=sched_cfg.t_warmup,
+            warmup_steps=int(sched_cfg.t_warmup),
         )
     elif sched_cfg.name == SchedulerType.max_scheduler:
         return MaxScheduler(
-            grad_clip_warmup_steps=sched_cfg.grad_clip_warmup_steps,
+            grad_clip_warmup_steps=None
+            if sched_cfg.grad_clip_warmup_steps is None
+            else int(sched_cfg.grad_clip_warmup_steps),
             grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
             sched1=build_scheduler(cfg, replace(sched_cfg, name=SchedulerType.cosine_with_warmup)),
             sched2=build_scheduler(cfg, replace(sched_cfg, name=SchedulerType.inverse_sqrt_with_warmup)),
         )
     elif sched_cfg.name == SchedulerType.constant:
         return ConstantScheduler(
-            grad_clip_warmup_steps=sched_cfg.grad_clip_warmup_steps,
+            grad_clip_warmup_steps=None
+            if sched_cfg.grad_clip_warmup_steps is None
+            else int(sched_cfg.grad_clip_warmup_steps),
             grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
         )
     else:
diff --git a/olmo/train.py b/olmo/train.py
index 2207b0552..d9710f453 100644
--- a/olmo/train.py
+++ b/olmo/train.py
@@ -26,6 +26,7 @@
 from .checkpoint import Checkpointer, FullCheckpointer, build_sharded_checkpointer
 from .config import (
     CheckpointType,
+    SchedulerUnits,
     ShardedCheckpointerType,
     SpeedMonitorConfig,
     TrainConfig,
@@ -122,6 +123,14 @@ def dataset(self) -> IterableDataset:
         assert isinstance(self.train_loader.dataset, IterableDataset)
         return self.train_loader.dataset
 
+    @property
+    def tokens_per_batch(self) -> int:
+        return self.cfg.global_train_batch_size * self.cfg.model.max_sequence_length
+
+    @property
+    def batches_per_epoch(self) -> int:
+        return self.dataset.total_size // self.cfg.global_train_batch_size
+
     @property
     def max_epochs(self) -> int:
         if isinstance(self.cfg.max_duration, str) and self.cfg.max_duration.endswith("ep"):
@@ -137,21 +146,59 @@ def max_steps(self) -> int:
             if self.cfg.max_duration.endswith("T"):
                 # convert to float *first* to handle scientific notation
                 max_tokens = int(float(self.cfg.max_duration[:-1].strip()))
-                tokens_remaining = max_tokens - self.global_train_tokens_seen
-                tokens_per_batch = self.cfg.global_train_batch_size * self.cfg.model.max_sequence_length
-                steps_remaining = tokens_remaining // tokens_per_batch
+                tokens_remaining = max(max_tokens - self.global_train_tokens_seen, 0)
+                steps_remaining = tokens_remaining // self.tokens_per_batch
                 return self.global_step + steps_remaining
             elif self.cfg.max_duration.endswith("ep"):
                 max_epochs = int(self.cfg.max_duration[:-2].strip())
-                examples_per_epoch = self.dataset.total_size
-                steps_per_epoch = examples_per_epoch // self.cfg.global_train_batch_size
-                return max_epochs * steps_per_epoch
+                return max_epochs * self.batches_per_epoch
             else:
                 # convert to float *first* to handle scientific notation
                 return int(float(self.cfg.max_duration))
         else:
             raise TypeError(f"expected int or str for 'max_duration', found {type(self.cfg.max_duration)}")
 
+    @property
+    def max_tokens(self) -> int:
+        if isinstance(self.cfg.max_duration, int):
+            return (
+                self.global_train_tokens_seen
+                + max(self.cfg.max_duration - self.global_step, 0) * self.tokens_per_batch
+            )
+        elif isinstance(self.cfg.max_duration, str):
+            if self.cfg.max_duration.endswith("T"):
+                # convert to float *first* to handle scientific notation
+                return int(float(self.cfg.max_duration[:-1].strip()))
+            elif self.cfg.max_duration.endswith("ep"):
+                max_epochs = int(self.cfg.max_duration[:-2].strip())
+                return max_epochs * self.batches_per_epoch * self.tokens_per_batch
+            else:
+                # convert to float *first* to handle scientific notation
+                return (
+                    self.global_train_tokens_seen
+                    + max(int(float(self.cfg.max_duration)) - self.global_step, 0) * self.tokens_per_batch
+                )
+        else:
+            raise TypeError(f"expected int or str for 'max_duration', found {type(self.cfg.max_duration)}")
+
+    @property
+    def scheduler_current(self) -> int:
+        if self.cfg.scheduler.units == SchedulerUnits.steps:
+            return self.global_step
+        elif self.cfg.scheduler.units == SchedulerUnits.tokens:
+            return self.global_train_tokens_seen
+        else:
+            raise NotImplementedError(self.cfg.scheduler.units)
+
+    @property
+    def scheduler_max(self) -> int:
+        if self.cfg.scheduler.units == SchedulerUnits.steps:
+            return self.max_steps
+        elif self.cfg.scheduler.units == SchedulerUnits.tokens:
+            return self.max_tokens
+        else:
+            raise NotImplementedError(self.cfg.scheduler.units)
+
     def trainer_state_dict(self) -> Dict[str, Any]:
         return {
             "epoch": self.epoch,
@@ -233,7 +280,7 @@ def load_trainer_state_dict(self, state_dict: Dict[str, Any]) -> None:
         # Reset learning rate and weight decay to the values from the config, not the checkpoint.
         log.info("Resetting learning rate...")
         new_learning_rate = self.scheduler.get_lr(
-            self.cfg.optimizer.learning_rate, self.global_step, self.max_steps
+            self.cfg.optimizer.learning_rate, self.scheduler_current, self.scheduler_max
         )
         for group in self.optim.param_groups:
             group["lr"] = new_learning_rate
@@ -572,12 +619,14 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
             # TODO (epwalsh): if we want to enable different LRs or gradient clipping settings per group
             # we should pass `group["initial_lr"]` or `group["initial_max_grad_norm"]` here instead of
             # the corresponding values from `self.cfg`.
-            group["lr"] = self.scheduler.get_lr(self.cfg.optimizer.learning_rate, self.global_step, self.max_steps)
+            group["lr"] = self.scheduler.get_lr(
+                self.cfg.optimizer.learning_rate, self.scheduler_current, self.scheduler_max
+            )
             group["max_grad_norm"] = self.scheduler.get_max_grad_norm(
-                self.cfg.max_grad_norm, self.global_step, self.max_steps
+                self.cfg.max_grad_norm, self.scheduler_current, self.scheduler_max
             )
             group["max_grad_norm_ratio"] = self.scheduler.get_max_grad_norm(
-                self.cfg.max_grad_norm_ratio, self.global_step, self.max_steps
+                self.cfg.max_grad_norm_ratio, self.scheduler_current, self.scheduler_max
             )
 
         # Optimizer step.
diff --git a/scripts/train.py b/scripts/train.py
index 710cf0255..de97e31be 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -212,7 +212,7 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
                 trainer.scheduler = BoltOnWarmupScheduler.wrap(
                     trainer.scheduler,
                     trainer.global_step,
-                    trainer.global_step + cfg.scheduler.t_warmup,
+                    int(trainer.global_step + cfg.scheduler.t_warmup),
                 )
 
         if cfg.force_save_unsharded: