diff --git a/docs/src/models/recurrence.md b/docs/src/models/recurrence.md
index 3aff38597c..35da5697ae 100644
--- a/docs/src/models/recurrence.md
+++ b/docs/src/models/recurrence.md
@@ -173,7 +173,7 @@ Flux.reset!(m)
 [m(x) for x in seq_init]
 
 ps = Flux.params(m)
-opt= ADAM(1e-3)
+opt= Adam(1e-3)
 Flux.train!(loss, ps, data, opt)
 ```
diff --git a/docs/src/saving.md b/docs/src/saving.md
index 2ec6d94372..80332d4a1d 100644
--- a/docs/src/saving.md
+++ b/docs/src/saving.md
@@ -135,6 +135,6 @@ You can store the optimiser state alongside the model, to resume training
 exactly where you left off. BSON is smart enough to [cache values](https://github.com/JuliaIO/BSON.jl/blob/v0.3.4/src/write.jl#L71) and insert links when saving, but only if it knows everything to be saved up front. Thus models and optimizers must be saved together to have the latter work after restoring.
 
 ```julia
-opt = ADAM()
+opt = Adam()
 @save "model-$(now()).bson" model opt
 ```
diff --git a/docs/src/training/optimisers.md b/docs/src/training/optimisers.md
index e1fd1e9894..9455047836 100644
--- a/docs/src/training/optimisers.md
+++ b/docs/src/training/optimisers.md
@@ -39,7 +39,7 @@ for p in (W, b)
 end
 ```
 
-An optimiser `update!` accepts a parameter and a gradient, and updates the parameter according to the chosen rule. We can also pass `opt` to our [training loop](training.md), which will update all parameters of the model in a loop. However, we can now easily replace `Descent` with a more advanced optimiser such as `ADAM`.
+An optimiser `update!` accepts a parameter and a gradient, and updates the parameter according to the chosen rule. We can also pass `opt` to our [training loop](training.md), which will update all parameters of the model in a loop. However, we can now easily replace `Descent` with a more advanced optimiser such as `Adam`.
 
 ## Optimiser Reference
 
@@ -51,15 +51,15 @@ Descent
 Momentum
 Nesterov
 RMSProp
-ADAM
-RADAM
+Adam
+RAdam
 AdaMax
-ADAGrad
-ADADelta
+AdaGrad
+AdaDelta
 AMSGrad
-NADAM
-ADAMW
-OADAM
+NAdam
+AdamW
+OAdam
 AdaBelief
 ```
@@ -182,7 +182,7 @@ WeightDecay
 Gradient clipping is useful for training recurrent neural networks, which have a tendency to suffer from the exploding gradient problem. An example usage is
 
 ```julia
-opt = Optimiser(ClipValue(1e-3), ADAM(1e-3))
+opt = Optimiser(ClipValue(1e-3), Adam(1e-3))
 ```
 
 ```@docs
diff --git a/src/Flux.jl b/src/Flux.jl
index a5eaec7ce4..0cacbd419a 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -29,9 +29,9 @@ include("optimise/Optimise.jl")
 using .Optimise
 using .Optimise: @epochs
 using .Optimise: skip
-export Descent, ADAM, Momentum, Nesterov, RMSProp,
-  ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, OADAM,
-  ADAMW, RADAM, AdaBelief, InvDecay, ExpDecay,
+export Descent, Adam, Momentum, Nesterov, RMSProp,
+  AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, OAdam,
+  AdamW, RAdam, AdaBelief, InvDecay, ExpDecay,
   WeightDecay, ClipValue, ClipNorm
 
 using CUDA
diff --git a/src/deprecations.jl b/src/deprecations.jl
index eb1f2fdcda..6719bd39e2 100644
--- a/src/deprecations.jl
+++ b/src/deprecations.jl
@@ -71,3 +71,12 @@ LSTMCell(in::Integer, out::Integer; kw...) = LSTMCell(in => out; kw...)
 GRUCell(in::Integer, out::Integer; kw...) = GRUCell(in => out; kw...)
 GRUv3Cell(in::Integer, out::Integer; kw...) = GRUv3Cell(in => out; kw...)
+
+# Optimisers with old naming convention
+Base.@deprecate_binding ADAM Adam
+Base.@deprecate_binding NADAM NAdam
+Base.@deprecate_binding ADAMW AdamW
+Base.@deprecate_binding RADAM RAdam
+Base.@deprecate_binding OADAM OAdam
+Base.@deprecate_binding ADAGrad AdaGrad
+Base.@deprecate_binding ADADelta AdaDelta
diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl
index 010cbfc9bb..e691ce0170 100644
--- a/src/optimise/Optimise.jl
+++ b/src/optimise/Optimise.jl
@@ -4,8 +4,8 @@ using LinearAlgebra
 import ArrayInterface
 
 export train!, update!,
-  Descent, ADAM, Momentum, Nesterov, RMSProp,
-  ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM, OADAM, AdaBelief,
+  Descent, Adam, Momentum, Nesterov, RMSProp,
+  AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW,RAdam, OAdam, AdaBelief,
   InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser,
   ClipValue, ClipNorm
diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl
index ce78586fff..ce72a4b0ce 100644
--- a/src/optimise/optimisers.jl
+++ b/src/optimise/optimisers.jl
@@ -147,9 +147,9 @@ function apply!(o::RMSProp, x, Δ)
 end
 
 """
-    ADAM(η = 0.001, β::Tuple = (0.9, 0.999), ϵ = $EPS)
+    Adam(η = 0.001, β::Tuple = (0.9, 0.999), ϵ = $EPS)
 
-[ADAM](https://arxiv.org/abs/1412.6980) optimiser.
+[Adam](https://arxiv.org/abs/1412.6980) optimiser.
 
 # Parameters
 - Learning rate (`η`): Amount by which gradients are discounted before updating
@@ -159,21 +159,21 @@ end
 
 # Examples
 ```julia
-opt = ADAM()
+opt = Adam()
 
-opt = ADAM(0.001, (0.9, 0.8))
+opt = Adam(0.001, (0.9, 0.8))
 ```
 """
-mutable struct ADAM <: AbstractOptimiser
+mutable struct Adam <: AbstractOptimiser
   eta::Float64
   beta::Tuple{Float64,Float64}
   epsilon::Float64
   state::IdDict{Any, Any}
 end
-ADAM(η::Real = 0.001, β::Tuple = (0.9, 0.999), ϵ::Real = EPS) = ADAM(η, β, ϵ, IdDict())
-ADAM(η::Real, β::Tuple, state::IdDict) = ADAM(η, β, EPS, state)
+Adam(η::Real = 0.001, β::Tuple = (0.9, 0.999), ϵ::Real = EPS) = Adam(η, β, ϵ, IdDict())
+Adam(η::Real, β::Tuple, state::IdDict) = Adam(η, β, EPS, state)
 
-function apply!(o::ADAM, x, Δ)
+function apply!(o::Adam, x, Δ)
   η, β = o.eta, o.beta
 
   mt, vt, βp = get!(o.state, x) do
@@ -189,9 +189,9 @@ end
 
 """
-    RADAM(η = 0.001, β::Tuple = (0.9, 0.999), ϵ = $EPS)
+    RAdam(η = 0.001, β::Tuple = (0.9, 0.999), ϵ = $EPS)
 
-[Rectified ADAM](https://arxiv.org/abs/1908.03265) optimizer.
+[Rectified Adam](https://arxiv.org/abs/1908.03265) optimizer.
 
 # Parameters
 - Learning rate (`η`): Amount by which gradients are discounted before updating
@@ -201,21 +201,21 @@ end
 
 # Examples
 ```julia
-opt = RADAM()
+opt = RAdam()
 
-opt = RADAM(0.001, (0.9, 0.8))
+opt = RAdam(0.001, (0.9, 0.8))
 ```
 """
-mutable struct RADAM <: AbstractOptimiser
+mutable struct RAdam <: AbstractOptimiser
   eta::Float64
   beta::Tuple{Float64,Float64}
   epsilon::Float64
   state::IdDict{Any, Any}
 end
-RADAM(η::Real = 0.001, β::Tuple = (0.9, 0.999), ϵ::Real = EPS) = RADAM(η, β, ϵ, IdDict())
-RADAM(η::Real, β::Tuple, state::IdDict) = RADAM(η, β, EPS, state)
+RAdam(η::Real = 0.001, β::Tuple = (0.9, 0.999), ϵ::Real = EPS) = RAdam(η, β, ϵ, IdDict())
+RAdam(η::Real, β::Tuple, state::IdDict) = RAdam(η, β, EPS, state)
 
-function apply!(o::RADAM, x, Δ)
+function apply!(o::RAdam, x, Δ)
   η, β = o.eta, o.beta
   ρ∞ = 2/(1-β[2])-1
@@ -241,7 +241,7 @@ end
 """
     AdaMax(η = 0.001, β::Tuple = (0.9, 0.999), ϵ = $EPS)
 
-[AdaMax](https://arxiv.org/abs/1412.6980) is a variant of ADAM based on the ∞-norm.
+[AdaMax](https://arxiv.org/abs/1412.6980) is a variant of Adam based on the ∞-norm.
 
 # Parameters
 - Learning rate (`η`): Amount by which gradients are discounted before updating
@@ -281,10 +281,10 @@ function apply!(o::AdaMax, x, Δ)
 end
 
 """
-    OADAM(η = 0.0001, β::Tuple = (0.5, 0.9), ϵ = $EPS)
+    OAdam(η = 0.001, β::Tuple = (0.5, 0.9), ϵ = $EPS)
 
-[OADAM](https://arxiv.org/abs/1711.00141) (Optimistic ADAM)
-is a variant of ADAM adding an "optimistic" term suitable for adversarial training.
+[OAdam](https://arxiv.org/abs/1711.00141) (Optimistic Adam)
+is a variant of Adam adding an "optimistic" term suitable for adversarial training.
 
 # Parameters
 - Learning rate (`η`): Amount by which gradients are discounted before updating
@@ -294,21 +294,21 @@ is a variant of ADAM adding an "optimistic" term suitable for adversarial traini
 
 # Examples
 ```julia
-opt = OADAM()
+opt = OAdam()
 
-opt = OADAM(0.001, (0.9, 0.995))
+opt = OAdam(0.001, (0.9, 0.995))
 ```
 """
-mutable struct OADAM <: AbstractOptimiser
+mutable struct OAdam <: AbstractOptimiser
   eta::Float64
   beta::Tuple{Float64,Float64}
   epsilon::Float64
   state::IdDict{Any, Any}
 end
-OADAM(η::Real = 0.001, β::Tuple = (0.5, 0.9), ϵ::Real = EPS) = OADAM(η, β, ϵ, IdDict())
-OADAM(η::Real, β::Tuple, state::IdDict) = RMSProp(η, β, EPS, state)
+OAdam(η::Real = 0.001, β::Tuple = (0.5, 0.9), ϵ::Real = EPS) = OAdam(η, β, ϵ, IdDict())
+OAdam(η::Real, β::Tuple, state::IdDict) = OAdam(η, β, EPS, state)
 
-function apply!(o::OADAM, x, Δ)
+function apply!(o::OAdam, x, Δ)
   η, β = o.eta, o.beta
 
   mt, vt, Δ_, βp = get!(o.state, x) do
@@ -326,9 +326,9 @@ end
 
 """
-    ADAGrad(η = 0.1, ϵ = $EPS)
+    AdaGrad(η = 0.1, ϵ = $EPS)
 
-[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimizer. It has
+[AdaGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimizer. It has
 parameter specific learning rates based on how frequently it is updated.
 Parameters don't need tuning.
 
 # Parameters
@@ -338,20 +338,20 @@ Parameters don't need tuning.
 
 # Examples
 ```julia
-opt = ADAGrad()
+opt = AdaGrad()
 
-opt = ADAGrad(0.001)
+opt = AdaGrad(0.001)
 ```
 """
-mutable struct ADAGrad <: AbstractOptimiser
+mutable struct AdaGrad <: AbstractOptimiser
   eta::Float64
   epsilon::Float64
   acc::IdDict
 end
-ADAGrad(η::Real = 0.1, ϵ::Real = EPS) = ADAGrad(η, ϵ, IdDict())
-ADAGrad(η::Real, state::IdDict) = ADAGrad(η, EPS, state)
+AdaGrad(η::Real = 0.1, ϵ::Real = EPS) = AdaGrad(η, ϵ, IdDict())
+AdaGrad(η::Real, state::IdDict) = AdaGrad(η, EPS, state)
 
-function apply!(o::ADAGrad, x, Δ)
+function apply!(o::AdaGrad, x, Δ)
   η = o.eta
   acc = get!(() -> fill!(similar(x), o.epsilon), o.acc, x)::typeof(x)
   @. acc += Δ * conj(Δ)
@@ -359,9 +359,9 @@ end
 
 """
-    ADADelta(ρ = 0.9, ϵ = $EPS)
+    AdaDelta(ρ = 0.9, ϵ = $EPS)
 
-[ADADelta](https://arxiv.org/abs/1212.5701) is a version of ADAGrad adapting its learning
+[AdaDelta](https://arxiv.org/abs/1212.5701) is a version of AdaGrad adapting its learning
 rate based on a window of past gradient updates.
 Parameters don't need tuning.
 
 # Parameters
@@ -370,20 +370,20 @@ Parameters don't need tuning.
 
 # Examples
 ```julia
-opt = ADADelta()
+opt = AdaDelta()
 
-opt = ADADelta(0.89)
+opt = AdaDelta(0.89)
 ```
 """
-mutable struct ADADelta <: AbstractOptimiser
+mutable struct AdaDelta <: AbstractOptimiser
   rho::Float64
   epsilon::Float64
   state::IdDict{Any, Any}
 end
-ADADelta(ρ::Real = 0.9, ϵ::Real = EPS) = ADADelta(ρ, ϵ, IdDict())
-ADADelta(ρ::Real, state::IdDict) = ADADelta(ρ, EPS, state)
+AdaDelta(ρ::Real = 0.9, ϵ::Real = EPS) = AdaDelta(ρ, ϵ, IdDict())
+AdaDelta(ρ::Real, state::IdDict) = AdaDelta(ρ, EPS, state)
 
-function apply!(o::ADADelta, x, Δ)
+function apply!(o::AdaDelta, x, Δ)
   ρ = o.rho
   acc, Δacc = get!(() -> (zero(x), zero(x)), o.state, x)::NTuple{2,typeof(x)}
   @. acc = ρ * acc + (1 - ρ) * Δ * conj(Δ)
@@ -397,7 +397,7 @@ end
 """
     AMSGrad(η = 0.001, β::Tuple = (0.9, 0.999), ϵ = $EPS)
 
-The [AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) version of the ADAM
+The [AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) version of the Adam
 optimiser. Parameters don't need tuning.
 
 # Parameters
@@ -436,9 +436,9 @@ function apply!(o::AMSGrad, x, Δ)
 end
 
 """
-    NADAM(η = 0.001, β::Tuple = (0.9, 0.999), ϵ = $EPS)
+    NAdam(η = 0.001, β::Tuple = (0.9, 0.999), ϵ = $EPS)
 
-[NADAM](https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ) is a Nesterov variant of ADAM.
+[NAdam](https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ) is a Nesterov variant of Adam.
 Parameters don't need tuning.
 
 # Parameters
@@ -449,21 +449,21 @@ Parameters don't need tuning.
 
 # Examples
 ```julia
-opt = NADAM()
+opt = NAdam()
 
-opt = NADAM(0.002, (0.89, 0.995))
+opt = NAdam(0.002, (0.89, 0.995))
 ```
 """
-mutable struct NADAM <: AbstractOptimiser
+mutable struct NAdam <: AbstractOptimiser
   eta::Float64
   beta::Tuple{Float64, Float64}
   epsilon::Float64
   state::IdDict{Any, Any}
 end
-NADAM(η::Real = 0.001, β = (0.9, 0.999), ϵ::Real = EPS) = NADAM(η, β, ϵ, IdDict())
-NADAM(η::Real, β::Tuple, state::IdDict) = NADAM(η, β, EPS, state)
+NAdam(η::Real = 0.001, β = (0.9, 0.999), ϵ::Real = EPS) = NAdam(η, β, ϵ, IdDict())
+NAdam(η::Real, β::Tuple, state::IdDict) = NAdam(η, β, EPS, state)
 
-function apply!(o::NADAM, x, Δ)
+function apply!(o::NAdam, x, Δ)
   η, β = o.eta, o.beta
 
   mt, vt, βp = get!(o.state, x) do
@@ -480,9 +480,9 @@ end
 
 """
-    ADAMW(η = 0.001, β::Tuple = (0.9, 0.999), decay = 0)
+    AdamW(η = 0.001, β::Tuple = (0.9, 0.999), decay = 0)
 
-[ADAMW](https://arxiv.org/abs/1711.05101) is a variant of ADAM fixing (as in repairing) its
+[AdamW](https://arxiv.org/abs/1711.05101) is a variant of Adam fixing (as in repairing) its
 weight decay regularization.
 
 # Parameters
@@ -494,19 +494,19 @@ weight decay regularization.
 
 # Examples
 ```julia
-opt = ADAMW()
+opt = AdamW()
 
-opt = ADAMW(0.001, (0.89, 0.995), 0.1)
+opt = AdamW(0.001, (0.89, 0.995), 0.1)
 ```
 """
-ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0) =
-  Optimiser(ADAM(η, β), WeightDecay(decay))
+AdamW(η = 0.001, β = (0.9, 0.999), decay = 0) =
+  Optimiser(Adam(η, β), WeightDecay(decay))
 
 """
     AdaBelief(η = 0.001, β::Tuple = (0.9, 0.999), ϵ = $EPS)
 
 The [AdaBelief](https://arxiv.org/abs/2010.07468) optimiser is a variant of the well-known
-ADAM optimiser.
+Adam optimiser.
 
 # Parameters
 - Learning rate (`η`): Amount by which gradients are discounted before updating
@@ -537,7 +537,7 @@ function apply!(o::AdaBelief, x, Δ)
       (zero(x), zero(x), Float64[β[1], β[2]])
   end :: Tuple{typeof(x), typeof(x), Vector{Float64}}
 
-  #= st is a variance and can go to zero. This is in contrast to ADAM, which uses the
+  #= st is a variance and can go to zero. This is in contrast to Adam, which uses the
   second moment which is usually far enough from zero. This is problematic, since st
   can be slightly negative due to numerical error, and the square root below will fail.
   Also, if we want to differentiate through the optimizer, √0 is not differentiable. =#
@@ -643,10 +643,10 @@ for more general scheduling techniques.
 `ExpDecay` is typically composed with other optimizers
 as the last transformation of the gradient:
 ```julia
-opt = Optimiser(ADAM(), ExpDecay(1.0))
+opt = Optimiser(Adam(), ExpDecay(1.0))
 ```
 Note: you may want to start with `η=1` in `ExpDecay` when combined with other
-optimizers (`ADAM` in this case) that have their own learning rate.
+optimizers (`Adam` in this case) that have their own learning rate.
 """
 mutable struct ExpDecay <: AbstractOptimiser
   eta::Float64
@@ -681,7 +681,7 @@ with coefficient ``λ`` to the loss.
 
 # Examples
 ```julia
-opt = Optimiser(WeightDecay(1f-4), ADAM())
+opt = Optimiser(WeightDecay(1f-4), Adam())
 ```
 """
 mutable struct WeightDecay <: AbstractOptimiser
diff --git a/test/optimise.jl b/test/optimise.jl
index 9c358a6825..e922d3c0b8 100644
--- a/test/optimise.jl
+++ b/test/optimise.jl
@@ -10,8 +10,8 @@ using Random
   # so that w and w' are different
   Random.seed!(84)
   w = randn(10, 10)
-  @testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(),
-                       NADAM(), RADAM(), Descent(0.1), ADAM(), OADAM(), AdaBelief(),
+  @testset for opt in [AdamW(), AdaGrad(0.1), AdaMax(), AdaDelta(0.9), AMSGrad(),
+                       NAdam(), RAdam(), Descent(0.1), Adam(), OAdam(), AdaBelief(),
                        Nesterov(), RMSProp(), Momentum()]
     Random.seed!(42)
     w′ = randn(10, 10)
@@ -34,7 +34,7 @@ end
    Random.seed!(42)
    w′ = randn(10, 10)
    loss(x) = Flux.Losses.mse(w*x, w′*x)
-   opt = Optimiser(Opt(), ADAM(0.001))
+   opt = Optimiser(Opt(), Adam(0.001))
    for t = 1:10^5
      θ = Params([w′])
      x = rand(10)
@@ -202,7 +202,7 @@ end
 end
 
 # Flux PR #1776
-# We need to test that optimisers like ADAM that maintain an internal momentum
+# We need to test that optimisers like Adam that maintain an internal momentum
 # estimate properly calculate the second-order statistics on the gradients as
 # the flow backward through the model. Previously, we would calculate second-
 # order statistics via `Δ^2` rather than the complex-aware `Δ * conj(Δ)`, which
@@ -210,7 +210,7 @@ end
 # a simple optimization is montonically decreasing (up to learning step effects)
 @testset "Momentum Optimisers and complex values" begin
   # Test every optimizer that has momentum internally
-  for opt_ctor in [ADAM, RMSProp, RADAM, OADAM, ADAGrad, ADADelta, NADAM, AdaBelief]
+  for opt_ctor in [Adam, RMSProp, RAdam, OAdam, AdaGrad, AdaDelta, NAdam, AdaBelief]
     # Our "model" is just a complex number
     w = zeros(ComplexF32, 1)
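
For reference, a minimal usage sketch of the renamed optimiser API introduced above. This snippet is illustrative and not part of the patch; it assumes a Flux build that includes these changes, and whether the deprecated bindings print a warning depends on Julia's `--depwarn` setting.

```julia
using Flux

# New canonical names, exported from Flux and Flux.Optimise by this change:
opt = Adam(1e-3)                              # formerly ADAM(1e-3)
opt = Optimiser(ClipValue(1e-3), Adam(1e-3))  # composition works as before

# The old all-caps names remain as deprecated aliases via Base.@deprecate_binding,
# so existing code still runs and constructs the renamed types:
legacy = ADAM(1e-3)
legacy isa Adam   # true
```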