Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements cosine decay with linear warmup #15

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions lib/polaris/schedules.ex
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,65 @@ defmodule Polaris.Schedules do
init_value * (cos + alpha)
end

@doc ~S"""
Cosine decay schedule with linear warmup.

For steps within the warmup period, the learning rate increases linearly
from `0` towards `init_value`. Once the warmup period is over, it follows a
cosine decay schedule starting from `init_value`:

$$
\gamma(t) =
\begin{cases}
\frac{t}{\text{warmup\_steps}} * \gamma_{\text{peak}} & \text{if } t < \text{warmup\_steps} \\
\gamma_{\text{peak}} * \left(\frac{1}{2}(1 - \alpha)(1 + \cos\pi \frac{t - \text{warmup\_steps}}{k}) + \alpha\right) & \text{otherwise}
\end{cases}
$$

Here $\gamma_{\text{peak}}$ is `init_value`. The number of steps spent decaying
is capped at $k$, so after `warmup_steps + decay_steps` steps the schedule
stays at $\alpha * \gamma_{\text{peak}}$.

Returns an arity-1 function that maps a step count to a learning rate.

## Options

  * `:warmup_steps` - number of warmup steps during which the learning rate
    increases linearly. Defaults to `10`.

  * `:decay_steps` - number of steps to apply decay for.
    $k$ in the above formulation. Defaults to `10`.

  * `:alpha` - minimum multiplier value for adjusting the learning rate.
    $\alpha$ in the above formulation. Defaults to `0.0`.
"""
def warmup_cosine_decay(init_value, opts \\ []) do
  # The peak value travels with the options so the numerical definition
  # receives everything it needs through a single keyword list.
  &apply_warmup_cosine_decay(&1, [{:init_value, init_value} | opts])
end

defnp apply_warmup_cosine_decay(step, opts \\ []) do
  opts = keyword!(opts, init_value: nil, warmup_steps: 10, decay_steps: 10, alpha: 0.0)

  peak = opts[:init_value]
  warmup_steps = opts[:warmup_steps]
  decay_steps = opts[:decay_steps]
  alpha = opts[:alpha]

  # Warmup branch: a straight line through the origin that reaches `peak`
  # at `warmup_steps`; the step is capped so the line never overshoots.
  ramp = peak / warmup_steps * Nx.min(step, warmup_steps)

  # Decay branch: count steps since the end of warmup (never negative),
  # cap them at `decay_steps`, and map that progress onto half a cosine wave.
  elapsed = Nx.max(step - warmup_steps, 0)
  angle = Nx.min(elapsed, decay_steps) / decay_steps * Nx.Constants.pi()
  cosine = (1 + Nx.cos(angle)) / 2
  annealed = peak * (alpha + (1 - alpha) * cosine)

  # Both branches are computed; pick the warmup value while still warming up.
  Nx.select(step < warmup_steps, ramp, annealed)
end

@doc ~S"""
Constant schedule.

Expand Down
55 changes: 55 additions & 0 deletions test/polaris/schedules_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,61 @@ defmodule Polaris.SchedulesTest do
end
end

describe "warmup_cosine_decay" do
  test "returns arity-1 function with required options" do
    schedule = warmup_cosine_decay(1.0e-2, warmup_steps: 5, decay_steps: 10)
    assert is_function(schedule, 1)
  end

  test "returns arity-1 function with additional options" do
    schedule = warmup_cosine_decay(1.0e-3, warmup_steps: 5, decay_steps: 10, alpha: 0.1)
    assert is_function(schedule, 1)
  end

  test "can be called as anonymous function" do
    schedule = warmup_cosine_decay(1.0e-2, warmup_steps: 5, decay_steps: 10)

    # Step 0 is the start of warmup, step 5 is the peak.
    Enum.each([{0, 0.0}, {5, 1.0e-2}], fn {step, expected} ->
      assert_all_close(schedule.(step), expected)
    end)
  end

  test "can be called within JIT" do
    schedule = warmup_cosine_decay(1.0e-2, warmup_steps: 5, decay_steps: 10)

    # Same checkpoints as the eager test, but through a freshly JIT-ed function.
    Enum.each([{0, 0.0}, {5, 1.0e-2}], fn {step, expected} ->
      assert_all_close(apply(jit(schedule), [step]), expected)
    end)
  end

  test "warmup phase increases linearly to peak value" do
    schedule = warmup_cosine_decay(1.0e-2, warmup_steps: 5, decay_steps: 10)

    # With 5 warmup steps, each step adds one fifth of the peak value.
    checkpoints = [{0, 0.0}, {2, 0.4 * 1.0e-2}, {4, 0.8 * 1.0e-2}, {5, 1.0e-2}]

    Enum.each(checkpoints, fn {step, expected} ->
      assert_all_close(schedule.(step), expected)
    end)
  end

  test "cosine decay phase follows cosine decay schedule after warmup" do
    schedule = warmup_cosine_decay(1.0e-2, warmup_steps: 5, decay_steps: 10, alpha: 0.0)

    # Peak at the end of warmup, half-way down mid-decay, zero once decay ends.
    checkpoints = [{5, 1.0e-2}, {10, 0.5 * 1.0e-2}, {15, 0.0}]

    Enum.each(checkpoints, fn {step, expected} ->
      assert_all_close(schedule.(step), expected)
    end)
  end

  test "cosine decay phase respects alpha value" do
    schedule = warmup_cosine_decay(1.0e-2, warmup_steps: 5, decay_steps: 10, alpha: 0.5)

    # alpha: 0.5 puts the floor at half the peak, so mid-decay sits at 0.75.
    checkpoints = [{5, 1.0e-2}, {10, 0.75 * 1.0e-2}, {15, 0.5 * 1.0e-2}]

    Enum.each(checkpoints, fn {step, expected} ->
      assert_all_close(schedule.(step), expected)
    end)
  end

  test "matches expected values at different step counts" do
    schedule = warmup_cosine_decay(1.0e-3, warmup_steps: 5, decay_steps: 10, alpha: 0.0)

    checkpoints = [{0, 0.0}, {5, 0.001}, {10, 0.0005}, {15, 0.0}]

    Enum.each(checkpoints, fn {step, expected} ->
      assert_all_close(schedule.(step), expected)
    end)
  end
end

describe "constant" do
test "returns arity-1 function with defaults" do
fun = constant(1.0e-2)
Expand Down