From 5d2664c9fe9f605b2ee1d9f798345e717f0c494f Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 11 Apr 2024 16:35:28 -0700 Subject: [PATCH] fix test --- tests/optim_test.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/optim_test.py b/tests/optim_test.py index 9b0a7a0d4..318f87b33 100644 --- a/tests/optim_test.py +++ b/tests/optim_test.py @@ -6,7 +6,9 @@ def test_linear_with_warmup_scheduler(): initial_lr = 1.0 max_steps = 10_000 - scheduler = LinearWithWarmup(grad_clip_warmup_steps=None, grad_clip_warmup_factor=None, warmup_steps=2000) + scheduler = LinearWithWarmup( + grad_clip_warmup_steps=None, grad_clip_warmup_factor=None, warmup_steps=2000, warmup_min_lr=None + ) assert scheduler.get_lr(initial_lr, 0, max_steps) == 0.1 assert scheduler.get_lr(initial_lr, 2000, max_steps) == 1.0 assert scheduler.get_lr(initial_lr, 10_000, max_steps) == 0.1 @@ -18,7 +20,11 @@ def test_bolt_on_warmup_scheduler(): max_steps = 11_000 alpha_f = 0.1 scheduler = LinearWithWarmup( - grad_clip_warmup_steps=None, grad_clip_warmup_factor=None, warmup_steps=1000, alpha_f=alpha_f + grad_clip_warmup_steps=None, + grad_clip_warmup_factor=None, + warmup_steps=1000, + alpha_f=alpha_f, + warmup_min_lr=None, ) scheduler2 = BoltOnWarmupScheduler.wrap(scheduler, 5000, 6000) assert scheduler.get_lr(initial_lr, 100, max_steps) > 0.0