Enzyme Testing + Caching in compute_gradients #640

Merged (14 commits) on May 15, 2024
1 change: 1 addition & 0 deletions .buildkite/pipeline.yml
@@ -251,6 +251,7 @@ env:
JULIA_AMDGPU_LOGGING_ENABLED: true
RETESTITEMS_TESTITEM_TIMEOUT: 10000
DATADEPS_ALWAYS_ACCEPT: true
JULIA_PKG_SERVER: ""
JULIA_NUM_THREADS: 8
GKSwstype: "100" # https://discourse.julialang.org/t/generation-of-documentation-fails-qt-qpa-xcb-could-not-connect-to-display/60988
SECRET_CODECOV_TOKEN: "jQ0BMTQgyZx7QGyU0Q2Ec7qB9mtE2q/tDu0FsfxvEG7/zOAGvXkyXrzIFFOQxvDoFcP+K2+hYZKMxicYdNqzr5wcxu505aNGN2GM3wyegAr+hO6q12bCFYx6qXzU9FLCCdeqINqn9gUSSOlGtWNFrbAlrTyz/D4Yo66TqBDzvaLL63FMnhCLaXW/zJt3hNuEAJaPY2O6Ze1rX2WZ3Y+i+s3uQ8aLImtoCJhPe8CRx+OhuYiTzGhynFfGntZ0738/1RN4gNM0S/hTC4gLE7XMVBanJpGh32rFaiDwW4zAyXKBrDkL3QA3MS1RvLTJxGJ085S16hCk0C4ddAhZCvIM9Q==;U2FsdGVkX1+bXdFeKMs5G79catOCyby2n07A2fg0FjVAvrjQLZ0yfvDS4paJiFikLkodho0khz2YALKb2Y0K6w=="
2 changes: 2 additions & 0 deletions .github/workflows/CI.yml
@@ -47,6 +47,8 @@ jobs:
${{ runner.os }}-test-
${{ runner.os }}-
- uses: julia-actions/julia-buildpkg@v1
env:
JULIA_PKG_SERVER: ""
- uses: julia-actions/julia-runtest@v1
env:
BACKEND_GROUP: "CPU"
2 changes: 2 additions & 0 deletions .github/workflows/Downgrade.yml
@@ -36,6 +36,8 @@ jobs:
with:
skip: Pkg,TOML
- uses: julia-actions/julia-buildpkg@v1
env:
JULIA_PKG_SERVER: ""
- uses: julia-actions/julia-runtest@v1
env:
BACKEND_GROUP: "CPU"
6 changes: 4 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.5.47"
version = "0.5.48"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -31,6 +31,7 @@ WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
@@ -47,6 +48,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
LuxComponentArraysExt = "ComponentArrays"
LuxDynamicExpressionsExt = "DynamicExpressions"
LuxDynamicExpressionsForwardDiffExt = ["DynamicExpressions", "ForwardDiff"]
LuxEnzymeExt = "Enzyme"
LuxFluxExt = "Flux"
LuxForwardDiffExt = "ForwardDiff"
LuxLuxAMDGPUExt = "LuxAMDGPU"
@@ -85,7 +87,7 @@ LuxAMDGPU = "0.2.2"
LuxCUDA = "0.3.2"
LuxCore = "0.1.14"
LuxDeviceUtils = "0.1.19"
LuxLib = "0.3.22"
LuxLib = "0.3.23"
LuxTestUtils = "0.1.15"
MLUtils = "0.4.3"
MPI = "0.20.19"
1 change: 1 addition & 0 deletions docs/src/api/Lux/contrib.md
@@ -35,6 +35,7 @@ basic building blocks which can be seamlessly composed to create complex trainin
Lux.Experimental.TrainState
Lux.Experimental.compute_gradients
Lux.Experimental.apply_gradients
Lux.Experimental.apply_gradients!
```

## Parameter Freezing
2 changes: 1 addition & 1 deletion examples/GravitationalWaveForm/Project.toml
@@ -25,4 +25,4 @@ LuxCUDA = "0.2, 0.3"
Optimization = "3"
OptimizationOptimJL = "0.1, 0.2"
OrdinaryDiffEq = "6"
SciMLSensitivity = "7"
SciMLSensitivity = "7.57"
2 changes: 1 addition & 1 deletion examples/HyperNet/main.jl
@@ -102,7 +102,7 @@ function train()
y = y |> dev
(gs, _, _, train_state) = Lux.Experimental.compute_gradients(
AutoZygote(), loss, (data_idx, x, y), train_state)
train_state = Lux.Experimental.apply_gradients(train_state, gs)
train_state = Lux.Experimental.apply_gradients!(train_state, gs)
end
ttime = time() - stime

2 changes: 1 addition & 1 deletion examples/NeuralODE/Project.toml
@@ -27,6 +27,6 @@ MLUtils = "0.2, 0.3, 0.4"
OneHotArrays = "0.1, 0.2"
Optimisers = "0.2, 0.3"
OrdinaryDiffEq = "6"
SciMLSensitivity = "7.45"
SciMLSensitivity = "7.57"
Statistics = "1"
Zygote = "0.6"
4 changes: 2 additions & 2 deletions examples/PolynomialFitting/main.jl
@@ -74,12 +74,12 @@ vjp_rule = AutoZygote()
function main(tstate::Lux.Experimental.TrainState, vjp, data, epochs)
data = data .|> gpu_device()
for epoch in 1:epochs
grads, loss, stats, tstate = Lux.Training.compute_gradients(
grads, loss, stats, tstate = Lux.Experimental.compute_gradients(
vjp, loss_function, data, tstate)
if epoch % 50 == 1 || epoch == epochs
@printf "Epoch: %3d \t Loss: %.5g\n" epoch loss
end
tstate = Lux.Training.apply_gradients(tstate, grads)
tstate = Lux.Experimental.apply_gradients!(tstate, grads)
end
return tstate
end
2 changes: 1 addition & 1 deletion examples/SimpleChains/main.jl
@@ -82,7 +82,7 @@ function train(model; rng=Xoshiro(0), kwargs...)
for (x, y) in train_dataloader
(gs, _, _, train_state) = Lux.Experimental.compute_gradients(
AutoZygote(), loss, (x, y), train_state)
train_state = Lux.Experimental.apply_gradients(train_state, gs)
train_state = Lux.Experimental.apply_gradients!(train_state, gs)
end
ttime = time() - stime

2 changes: 1 addition & 1 deletion examples/SimpleRNN/main.jl
@@ -157,7 +157,7 @@ function main(model_type)

gs, loss, _, train_state = Lux.Experimental.compute_gradients(
AutoZygote(), compute_loss, (x, y), train_state)
train_state = Lux.Experimental.apply_gradients(train_state, gs)
train_state = Lux.Experimental.apply_gradients!(train_state, gs)

@printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
end
4 changes: 2 additions & 2 deletions examples/SymbolicOptimalControl/Project.toml
@@ -30,7 +30,7 @@ Optimization = "3.24.3"
OptimizationOptimJL = "0.2.3"
OptimizationOptimisers = "0.2.1"
OrdinaryDiffEq = "6.74.1"
SciMLSensitivity = "7.56.2"
Statistics = "1.11.1"
SciMLSensitivity = "7.57"
Statistics = "1.11"
SymbolicRegression = "0.24.1"
SymbolicUtils = "1.5.1"
76 changes: 76 additions & 0 deletions ext/LuxEnzymeExt.jl
@@ -0,0 +1,76 @@
module LuxEnzymeExt

using ADTypes: AutoEnzyme
using ConcreteStructs: @concrete
using Enzyme: Enzyme, Active, Const, Duplicated
using Lux: Lux

@concrete struct CachedEnzymeExtras
dparameters
objective_function
st_wrap
stats_wrap
end

# Case I: We have CachedEnzymeExtras and objective_function is unchanged.
function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data,
ts::Lux.Experimental.TrainState{<:CachedEnzymeExtras, F}) where {F}
dps = Lux.__recursive_make_zero!!(ts.cache.dparameters)

_, loss = Enzyme.autodiff(
Enzyme.ReverseWithPrimal, ts.cache.objective_function, Active, Const(ts.model),
Duplicated(ts.parameters, dps), Const(ts.states), Const(data))

ts_new = __construct_new_trainstate(
ts.cache.st_wrap, ts.states, ts, objective_function, dps,
ts.cache.objective_function, ts.cache.st_wrap, ts.cache.stats_wrap)

return dps, loss, ts.cache.stats_wrap, ts_new
end

# Case II: We have CachedEnzymeExtras and objective_function is changed.
function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data,
ts::Lux.Experimental.TrainState{<:CachedEnzymeExtras}) where {F}
dps = Lux.__recursive_make_zero!!(ts.cache.dparameters)

obj_fn, st_wrap, stats_wrap = Lux.Experimental.__wrap_objective_function(
objective_function, ts.states)

_, loss = Enzyme.autodiff(Enzyme.ReverseWithPrimal, obj_fn, Active, Const(ts.model),
Duplicated(ts.parameters, dps), Const(ts.states), Const(data))

ts_new = __construct_new_trainstate(
st_wrap, ts.states, ts, objective_function, dps, obj_fn, st_wrap, stats_wrap)

return dps, loss, stats_wrap, ts_new
end

# Case III: Nothing is cached. First call to `compute_gradients`
function Lux.Experimental.compute_gradients(ad::AutoEnzyme, objective_function::F, data,
ts::Lux.Experimental.TrainState) where {F}
dps = Lux.__recursive_make_zero(ts.parameters)
cache = CachedEnzymeExtras(dps, nothing, nothing, nothing)
ts_new = Lux.Experimental.TrainState(
cache, nothing, ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step)
return Lux.Experimental.compute_gradients(ad, objective_function, data, ts_new)
end

# If `st_new` is of a new type, we will have to recompute the cache anyway. Force it by not
# storing the objective function.
function __construct_new_trainstate(
st_new::S, ::S, ts::Lux.Experimental.TrainState, objective_fn::O,
dps, obj_fn::O2, st_wrap, stats_wrap) where {S, O, O2}
cache = CachedEnzymeExtras(dps, obj_fn, st_wrap, stats_wrap)
return Lux.Experimental.TrainState(
cache, objective_fn, ts.model, ts.parameters, st_new, ts.optimizer_state, ts.step)
end

function __construct_new_trainstate(
st_new, _, ts::Lux.Experimental.TrainState, objective_fn::O,
dps, obj_fn::O2, st_wrap, stats_wrap) where {O, O2}
cache = CachedEnzymeExtras(dps, nothing, nothing, nothing)
return Lux.Experimental.TrainState(
cache, nothing, ts.model, ts.parameters, st_new, ts.optimizer_state, ts.step)
end

end
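
For context, a minimal training-loop sketch of how the three cases above are hit in practice. The model, loss function, and data below are placeholders, and the `TrainState(rng, model, optimizer)` constructor form is assumed; the `Lux.Experimental` calls follow the API added in this PR, so treat this as an illustrative sketch rather than a tested example.

```julia
using ADTypes, Enzyme, Lux, Optimisers, Random

# Placeholder objective function: returns (loss, updated_states, stats),
# the 3-tuple contract expected by `compute_gradients`.
function mse_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    return sum(abs2, ŷ .- y), st_, (;)
end

function train_sketch(; steps=10)
    model = Dense(2 => 2)                              # placeholder model
    x, y = rand(Float32, 2, 16), rand(Float32, 2, 16)  # placeholder data
    train_state = Lux.Experimental.TrainState(Xoshiro(0), model, Adam(0.01f0))
    for _ in 1:steps
        # Step 1 hits Case III (nothing cached yet, type-unstable); later steps
        # with the same objective function hit Case I and reuse the cached
        # shadow parameters and wrapped objective.
        grads, _, _, train_state = Lux.Experimental.compute_gradients(
            AutoEnzyme(), mse_loss, (x, y), train_state)
        train_state = Lux.Experimental.apply_gradients!(train_state, grads)
    end
    return train_state
end
```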
11 changes: 9 additions & 2 deletions ext/LuxOptimisersExt.jl
@@ -33,13 +33,20 @@ function Lux.Experimental.TrainState(
transform_variables::Union{Function, AbstractLuxDevice}=gpu_device())
ps, st = Lux.setup(rng, model) .|> transform_variables
st_opt = Optimisers.setup(optimizer, ps)
return Lux.Experimental.TrainState(model, ps, st, st_opt, 0)
return Lux.Experimental.TrainState(nothing, nothing, model, ps, st, st_opt, 0)
end

function Lux.Experimental.apply_gradients(ts::Lux.Experimental.TrainState, grads)
optimizer_state, ps = Optimisers.update(ts.optimizer_state, ts.parameters, grads)
return Lux.Experimental.TrainState(ts.cache, ts.objective_function, ts.model,
ps, ts.states, optimizer_state, ts.step + 1)
end

function Lux.Experimental.apply_gradients!(ts::Lux.Experimental.TrainState, grads)
Optimisers.update!(ts.optimizer_state, ts.parameters, grads)
return Lux.Experimental.TrainState(
ts.model, ps, ts.states, optimizer_state, ts.step + 1)
ts.cache, ts.objective_function, ts.model, ts.parameters,
ts.states, ts.optimizer_state, ts.step + 1)
end

# DistributedUtils
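
As a usage note, a hedged sketch of how the two update paths above differ for a caller (`train_state` and `grads` are assumed to come from an earlier `compute_gradients` call; behaviour inferred from the `Optimisers.update` vs `Optimisers.update!` calls):

```julia
# Option 1 (out-of-place): allocates fresh parameter and optimizer-state trees.
train_state = Lux.Experimental.apply_gradients(train_state, grads)

# Option 2 (in-place): `Optimisers.update!` mutates the existing trees and the
# returned TrainState reuses them; only the step counter is bumped.
train_state = Lux.Experimental.apply_gradients!(train_state, grads)
```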
2 changes: 1 addition & 1 deletion ext/LuxReverseDiffExt.jl
@@ -16,7 +16,7 @@ function Lux.Experimental.compute_gradients(::AutoReverseDiff, objective_functio
loss.deriv = true
ReverseDiff.reverse_pass!(tape)
@set! ts.states = st
return grads, loss, stats, ts
return grads, ReverseDiff.value(loss), stats, ts
end

# AoS to SoA conversion
2 changes: 1 addition & 1 deletion ext/LuxTrackerExt.jl
@@ -41,7 +41,7 @@ function Lux.Experimental.compute_gradients(::AutoTracker, objective_function::F
Tracker.back!(loss)
@set! ts.states = st
grads = fmap(Tracker.grad, ps_tracked)
return grads, loss, stats, ts
return grads, Tracker.value(loss), stats, ts
end

# AoS to SoA conversion
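
Both one-line changes above (ReverseDiff and Tracker) make the returned loss a plain number rather than a tracked scalar. A hedged sketch of the observable difference, reusing the placeholder `mse_loss`, `x`, `y`, and `train_state` from the earlier sketch:

```julia
using ADTypes, Tracker

grads, loss, _, train_state = Lux.Experimental.compute_gradients(
    AutoTracker(), mse_loss, (x, y), train_state)

# Previously `loss` came back as a `Tracker.TrackedReal`, which could leak
# tracking into logging or convergence checks; it should now be a plain real.
@assert !(loss isa Tracker.TrackedReal)
```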
72 changes: 69 additions & 3 deletions src/contrib/training.jl
@@ -8,15 +8,34 @@
- `states`: Non-trainable Variables of the `model`.
- `optimizer_state`: Optimizer State.
- `step`: Number of updates of the parameters made.

Internal fields:

- `cache`: Cached values. Implementations are free to use this for whatever they want.
- `objective_function`: Objective function might be cached.
"""
@concrete struct TrainState
@concrete struct TrainState{C, F}
cache::C
objective_function::F
model
parameters
states
optimizer_state
step::Int
end

function Base.show(io::IO, ts::TrainState)
println(io, "TrainState")
println(io, " model: ", ts.model)
println(io, " parameters: ", Lux.parameterlength(ts.parameters))
println(io, " states: ", Lux.statelength(ts.states))
println(io, " optimizer_state: ", ts.optimizer_state)
print(io, " step: ", ts.step)
ts.cache !== nothing && print(io, "\n cache: ", nameof(typeof(ts.cache)))
ts.objective_function !== nothing &&
print(io, "\n objective_function: ", nameof(typeof(ts.objective_function)))
end

"""
apply_gradients(ts::TrainState, grads)

@@ -26,13 +45,31 @@

- `ts`: [`TrainState`](@ref) object.
- `grads`: Gradients of the loss function wrt `ts.params`.
- `update_inplace`: Whether to update the parameters inplace or not.

## Returns

Updated [`TrainState`](@ref) object.
"""
function apply_gradients end

"""
apply_gradients!(ts::TrainState, grads)

Update the parameters stored in `ts` using the gradients `grads`. This is an inplace version
of [`apply_gradients`](@ref).

## Arguments

- `ts`: [`TrainState`](@ref) object.
- `grads`: Gradients of the loss function wrt `ts.params`.

## Returns

Updated [`TrainState`](@ref) object.
"""
function apply_gradients! end

"""
compute_gradients(ad::ADTypes.AbstractADType, objective_function::Function, data,
ts::TrainState)
@@ -46,6 +83,7 @@
| `AutoZygote` | `Zygote.jl` |
| `AutoReverseDiff` | `ReverseDiff.jl` |
| `AutoTracker` | `Tracker.jl` |
| `AutoEnzyme` | `Enzyme.jl` |

## Arguments

@@ -65,6 +103,20 @@
- `loss`: Loss from the objective function.
- `stats`: Any computed statistics from the objective function.
- `ts`: Updated Training State.

## Special Notes on Backends

- `AutoEnzyme`: `mode` is always ignored and Enzyme ReverseMode is used. The first call
to `compute_gradients` will be type-unstable. It is recommended to call this function
once outside of the training loop and use the returned train_state for type stability.
- `AutoReverseDiff`: `compile` is always ignored and the gradient tape is never compiled.

!!! danger

`grads` returned by this function might be aliased by the implementation of the gradient
backend. For example, if you cache the `grads` from step `i`, the new gradients
returned in step `i + 1` might be aliased by the old gradients. If you want to prevent
this, simply use `copy(grads)` or `deepcopy(grads)` to make a copy of the gradients.
"""
function compute_gradients(ad::ADTypes.AbstractADType, ::F, _, ::TrainState) where {F}
return __maybe_implemented_compute_gradients(ad)
@@ -74,9 +126,23 @@
throw(ArgumentError(lazy"Support for AD backend $(nameof(T)) has not been implemented yet!!!"))
end

for package in (:Zygote, :Tracker, :ReverseDiff)
for package in (:Zygote, :Tracker, :ReverseDiff, :Enzyme)
adtype = Symbol(:Auto, package)
msg = "Load `$(package)` with `using $(package)`/`import $(package)` before using this \
function!"
@eval function __maybe_implemented_compute_gradients(::ADTypes.$(adtype))
throw(ArgumentError(lazy"Load `$(package)` with `using $(package)`/`import $(package)` before using this function!"))
throw(ArgumentError($msg))
end
end

@inline function __wrap_objective_function(objective_function::F, st) where {F}
st_updated, stats = st, (;)

# Boxing here is intentional
wrapped_objective_function = (model, ps, st, data) -> begin
y, st_updated, stats = objective_function(model, ps, st, data)
return y
end

return wrapped_objective_function, st_updated, stats
end
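
To illustrate the aliasing caveat added to the `compute_gradients` docstring above, a small sketch (reusing the placeholder `mse_loss`, and assuming `train_state` and `data` already exist; `deepcopy` because the gradients form a nested structure):

```julia
function collect_grad_history(train_state, data; epochs=3)
    grad_history = []
    for _ in 1:epochs
        grads, _, _, train_state = Lux.Experimental.compute_gradients(
            AutoEnzyme(), mse_loss, data, train_state)
        # With the Enzyme backend the same gradient buffers are zeroed and
        # reused on every call, so copy them if they must outlive the next
        # `compute_gradients` call.
        push!(grad_history, deepcopy(grads))
        train_state = Lux.Experimental.apply_gradients!(train_state, grads)
    end
    return grad_history, train_state
end
```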
19 changes: 19 additions & 0 deletions src/utils.jl
@@ -287,3 +287,22 @@

@inline __size(x::AbstractArray) = size(x)
@inline __size(x::T) where {T} = hasmethod(size, Tuple{T}) ? size(x) : nothing

@inline __recursive_make_zero(x::Number) = zero(x)
@inline __recursive_make_zero(x::AbstractArray{<:Number}) = zero(x)
@inline __recursive_make_zero(x::AbstractArray) = map(__recursive_make_zero, x)
@inline __recursive_make_zero(x::Tuple) = map(__recursive_make_zero, x)
@inline __recursive_make_zero(x::NamedTuple{fields}) where {fields} = NamedTuple{fields}(map(
__recursive_make_zero, values(x)))
@inline __recursive_make_zero(::Nothing) = nothing
@inline __recursive_make_zero(v::Val) = v
@inline __recursive_make_zero(x) = fmap(__recursive_make_zero, x)

@inline __recursive_make_zero!!(x::Number) = zero(x)
@inline __recursive_make_zero!!(x::AbstractArray{<:Number}) = fill!(x, zero(eltype(x)))
@inline __recursive_make_zero!!(x::AbstractArray) = map(__recursive_make_zero!!, x)
@inline __recursive_make_zero!!(x::Tuple) = map(__recursive_make_zero!!, x)
@inline __recursive_make_zero!!(x::NamedTuple{fields}) where {fields} = NamedTuple{fields}(map(
__recursive_make_zero!!, values(x)))
@inline __recursive_make_zero!!(::Nothing) = nothing
@inline __recursive_make_zero!!(x) = fmap(__recursive_make_zero!!, x)

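
Finally, a quick sketch of what the two internal helpers added in `src/utils.jl` do on a nested parameter tree (illustrative values only; these are internal, underscore-prefixed functions used by the Enzyme extension to build and reset the shadow gradients):

```julia
using Lux, Random

ps = (layer_1 = (weight = rand(Float32, 2, 2), bias = rand(Float32, 2)),
      layer_2 = (weight = rand(Float32, 1, 2), bias = nothing))

# Out-of-place: returns a structurally identical tree of freshly allocated zeros.
dps = Lux.__recursive_make_zero(ps)

# In-place (`!!`): fills the existing arrays with zeros and rebuilds the
# surrounding immutable containers, so the array memory is reused.
dps = Lux.__recursive_make_zero!!(dps)
```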