From 009018e9aca7a5c926ec3e7d540e6b0504458807 Mon Sep 17 00:00:00 2001 From: Kshitij Gupta Date: Mon, 7 Aug 2023 12:10:28 -0400 Subject: [PATCH] Fixed AnnealingLR Class and Cosine Decay Schedule (#1008) * Fixed AnnealingLR Class and Cosine Decay Schedule * Update NeoXArgs docs automatically --------- Co-authored-by: github-actions --- configs/neox_arguments.md | 2 +- megatron/learning_rates.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index dc1f42a3d..c50e7ff01 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = e16af33 + Default = d3e481c current git hash of repository diff --git a/megatron/learning_rates.py b/megatron/learning_rates.py index 943efdf1a..cbf2cc2fc 100644 --- a/megatron/learning_rates.py +++ b/megatron/learning_rates.py @@ -64,7 +64,7 @@ def get_lr(self): """Learning rate decay functions from: https://openreview.net/pdf?id=BJYwwY9ll pg. 4""" - num_iters_ = min(self.num_iters, self.end_iter - self.warmup_iter) + num_iters_ = self.num_iters # Warmup. if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter: return float(self.start_lr) * num_iters_ / self.warmup_iter @@ -73,8 +73,8 @@ def get_lr(self): if self.decay_style == "linear": lr = self.start_lr * (self.end_iter - num_iters_) / self.end_iter elif self.decay_style == "cosine": - lr = ( - self.start_lr + lr = self.min_lr + ( + (self.start_lr-self.min_lr) / 2.0 * (math.cos(math.pi * num_iters_ / self.end_iter) + 1) )