Skip to content

Commit

Permalink
Minor correction to optimizer & add exp lr scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
pxl-th committed Nov 15, 2023
1 parent bc4a5c1 commit dd8fbdc
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions src/nn/adam.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,9 @@ function _step!(opt::Adam, θ::T, ∇::T, i; dispose::Bool) where T <: AbstractA
"instead: `$(size(θ))` vs `$(size(∇))`."))

# Debiasing.
lr = opt.lr * (1f0 - opt.β2^opt.current_step) / (1f0 - opt.β1^opt.current_step)
adam_step_kernel!(get_backend(opt))(
opt.μ[i], opt.ν[i], θ, ∇,
opt.lr, opt.β1, opt.β2, opt.ϵ; ndrange=length(θ))
opt.lr, opt.β1, opt.β2, opt.ϵ, opt.current_step; ndrange=length(θ))

dispose && KA.unsafe_free!(∇)

Expand All @@ -95,16 +94,29 @@ end

# Per-element Adam update with bias-corrected moment estimates.
#
# Each work-item `i` reads the gradient `∇[i]`, updates the running first
# moment `μ[i]` and second moment `ν[i]` in place, and then writes the
# updated parameter back into `Θ[i]`.
#
# Arguments:
# - `μ`, `ν`: first / second moment buffers, updated in place.
# - `Θ`: parameters, updated in place.
# - `∇`: gradient, read-only (`@Const`).
# - `lr`: learning rate.
# - `β1`, `β2`: decay rates of the first / second moment estimates.
# - `ϵ`: small constant added to the denominator for numerical stability.
# - `step`: current optimization step used for bias correction.
#   NOTE(review): `step` must be ≥ 1 — for `step == 0` both correction
#   denominators `1 - β^step` are zero.
@kernel function adam_step_kernel!(
    μ, ν, Θ, @Const(∇), lr::Float32,
    β1::Float32, β2::Float32, ϵ::Float32, step::UInt32,
)
    i = @index(Global)
    @inbounds ∇ᵢ = ∇[i]

    # Exponential moving averages of the gradient and of its square.
    ∇ᵢ² = ∇ᵢ * ∇ᵢ
    @inbounds μᵢ = μ[i] = β1 * μ[i] + (1f0 - β1) * ∇ᵢ
    @inbounds νᵢ = ν[i] = β2 * ν[i] + (1f0 - β2) * ∇ᵢ²

    # Debiasing.
    μ̂ = μᵢ / (1f0 - β1^step)
    ν̂ = νᵢ / (1f0 - β2^step)

    # Adam scales the step by the square root of the second moment estimate;
    # the parameter is written exactly once, using the debiased moments.
    @inbounds ωᵢ = Θ[i]
    @inbounds Θ[i] = ωᵢ - lr * μ̂ / (sqrt(ν̂) + ϵ)
end

"""
    exp_scheduler(lr_start::Float32, lr_end::Float32, steps::Int)

Return a function `step::Int -> Float32` that interpolates the learning rate
between `lr_start` and `lr_end` in log-space over `steps` steps.

The progress `step / steps` is clamped to `[0, 1]`, so the schedule yields
`lr_start` at step `0` and stays at `lr_end` for every `step ≥ steps`.

The returned function yields `0f0` when `step` is negative or when both
`lr_start` and `lr_end` are zero (the logarithm is not evaluated in that case).

NOTE(review): both rates are otherwise expected to be positive, since the
interpolation takes `log` of each of them, and `steps` is expected to be
positive as it is used as a divisor.
"""
function exp_scheduler(lr_start::Float32, lr_end::Float32, steps::Int)
    function _scheduler(step::Int)
        # Nothing to interpolate: invalid step or an all-zero schedule.
        (step < 0 || (iszero(lr_start) && iszero(lr_end))) && return 0f0

        # Fraction of the schedule that has elapsed, saturating at the ends.
        t = clamp(Float32(step / steps), 0f0, 1f0)
        # Linear interpolation of the logarithms == geometric interpolation.
        return exp(log(lr_start) * (1 - t) + log(lr_end) * t)
    end
    return _scheduler
end

0 comments on commit dd8fbdc

Please sign in to comment.