From dd8fbdce96aa7e01da6fee932ef1e0f10651a47f Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Wed, 15 Nov 2023 17:03:37 +0200 Subject: [PATCH] Minor correction to optimizer & add exp lr scheuler --- src/nn/adam.jl | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/src/nn/adam.jl b/src/nn/adam.jl index 18e9561..84e1d8c 100644 --- a/src/nn/adam.jl +++ b/src/nn/adam.jl @@ -83,10 +83,9 @@ function _step!(opt::Adam, θ::T, ∇::T, i; dispose::Bool) where T <: AbstractA "instead: `$(size(θ))` vs `$(size(∇))`.")) # Debiasing. - lr = opt.lr * √(1f0 - opt.β2^opt.current_step) / (1f0 - opt.β1^opt.current_step) adam_step_kernel!(get_backend(opt))( opt.μ[i], opt.ν[i], θ, ∇, - opt.lr, opt.β1, opt.β2, opt.ϵ; ndrange=length(θ)) + opt.lr, opt.β1, opt.β2, opt.ϵ, opt.current_step; ndrange=length(θ)) dispose && KA.unsafe_free!(∇) @@ -95,16 +94,29 @@ end @kernel function adam_step_kernel!( μ, ν, Θ, @Const(∇), lr::Float32, - β1::Float32, β2::Float32, ϵ::Float32, + β1::Float32, β2::Float32, ϵ::Float32, step::UInt32, ) i = @index(Global) @inbounds ∇ᵢ = ∇[i] - @inbounds ωᵢ = Θ[i] + ∇ᵢ² = ∇ᵢ^2 - ∇ᵢ² = ∇ᵢ * ∇ᵢ @inbounds μᵢ = μ[i] = β1 * μ[i] + (1f0 - β1) * ∇ᵢ @inbounds νᵢ = ν[i] = β2 * ν[i] + (1f0 - β2) * ∇ᵢ² - @inbounds Θ[i] = ωᵢ - (lr * μᵢ) / (√νᵢ + ϵ) + # Debiasing. + μ̂ = μᵢ / (1f0 - β1^step) + ν̂ = νᵢ / (1f0 - β2^step) + + @inbounds ωᵢ = Θ[i] + @inbounds Θ[i] = ωᵢ - lr * μ̂ / (√ν̂ + ϵ) end +function exp_scheduler(lr_start::Float32, lr_end::Float32, steps::Int) + function _scheduler(step::Int) + (step < 0 || (lr_start ≈ 0f0 && lr_end ≈ 0f0)) && return 0f0 + + t = clamp(Float32(step / steps), 0f0, 1f0) + return exp(log(lr_start) * (1 - t) + log(lr_end) * t) + end + return _scheduler +end