Enzyme Testing + Caching in compute_gradients
#640
Conversation
avik-pal commented on May 12, 2024 (edited)
- Add testing for normalization functions for Enzyme
- Caching in Training
- Inplace update for Optimisers
- Enzyme Training Utilities
- Rewrite the caching for Enzyme
- Needs some tests
- Train a simple enough MLP in the tests
Benchmark Results
| Benchmark suite | Current: 8bdde08 | Previous: 64ba96d | Ratio |
|---|---|---|---|
| Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) | 3639.25 ns | 3633 ns | 1.00 |
| Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) | 7171.75 ns | 7103.166666666667 ns | 1.01 |
| Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) | 21029 ns | 20759 ns | 1.01 |
| Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) | 9728 ns | 9595.8 ns | 1.01 |
| Dense(2 => 2)/cpu/reverse/Flux/(2, 128) | 9046.75 ns | 8806 ns | 1.03 |
| Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) | 4457.125 ns | 4427 ns | 1.01 |
| Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) | 1203.7539682539682 ns | 1206.5289256198348 ns | 1.00 |
| Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) | 1116.9872611464968 ns | 1119.1708860759493 ns | 1.00 |
| Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) | 1188.3161764705883 ns | 1198.7238805970148 ns | 0.99 |
| Dense(2 => 2)/cpu/forward/Flux/(2, 128) | 1795 ns | 1795.396551724138 ns | 1.00 |
| Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) | 178.88664812239222 ns | 178.75070028011206 ns | 1.00 |
| Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) | 17423 ns | 17362 ns | 1.00 |
| Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) | 17092 ns | 17443 ns | 0.98 |
| Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) | 37139 ns | 37119 ns | 1.00 |
| Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) | 28333 ns | 28172 ns | 1.01 |
| Dense(20 => 20)/cpu/reverse/Flux/(20, 128) | 19967 ns | 19957 ns | 1.00 |
| Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) | 16712 ns | 16821 ns | 0.99 |
| Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) | 4343.714285714285 ns | 4306.571428571428 ns | 1.01 |
| Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) | 3899.75 ns | 3846 ns | 1.01 |
| Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) | 4003.75 ns | 3947.25 ns | 1.01 |
| Dense(20 => 20)/cpu/forward/Flux/(20, 128) | 4996.428571428572 ns | 4824.714285714285 ns | 1.04 |
| Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) | 1658 ns | 1656.1 ns | 1.00 |
| Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) | 38849589.5 ns | 38507163 ns | 1.01 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) | 57463704.5 ns | 57582497 ns | 1.00 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) | 76284433.5 ns | 75605722.5 ns | 1.01 |
| Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) | 88942467 ns | 88334752 ns | 1.01 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) | 72781539 ns | 72169093 ns | 1.01 |
| Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) | 11978857 ns | 11692461 ns | 1.02 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) | 8337439 ns | 17394732.5 ns | 0.48 |
| Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) | 7019228 ns | 6995759 ns | 1.00 |
| Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) | 6969104 ns | 6978091 ns | 1.00 |
| Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) | 10049824 ns | 9930897 ns | 1.01 |
| Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) | 6377212 ns | 6387632 ns | 1.00 |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) | 713976867 ns | 693404244 ns | 1.03 |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) | 2834760052 ns | 2833937581 ns | 1.00 |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) | 145850873 ns | 156241063 ns | 0.93 |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) | 842710339 ns | 834588032 ns | 1.01 |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) | 2574059116 ns | 2548330832 ns | 1.01 |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) | 202158213 ns | 178313969 ns | 1.13 |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 16) | 655838421 ns | 678237525.5 ns | 0.97 |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 64) | 2765113753 ns | 2822783216 ns | 0.98 |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 2) | 131452603 ns | 120430437.5 ns | 1.09 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) | 172815028.5 ns | 175244675 ns | 0.99 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) | 645535092 ns | 651623284 ns | 0.99 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) | 34669482 ns | 45831746 ns | 0.76 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) | 165298740 ns | 165271861 ns | 1.00 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) | 641045050 ns | 639340636 ns | 1.00 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) | 30403729 ns | 30230425.5 ns | 1.01 |
| vgg16/cpu/forward/Flux/(32, 32, 3, 16) | 187925737 ns | 186370118 ns | 1.01 |
| vgg16/cpu/forward/Flux/(32, 32, 3, 64) | 717992556 ns | 708821435 ns | 1.01 |
| vgg16/cpu/forward/Flux/(32, 32, 3, 2) | 37845546 ns | 35641646.5 ns | 1.06 |
| Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) | 1276726758 ns | 1269088947 ns | 1.01 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) | 1854655259 ns | 1867861740 ns | 0.99 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) | 2195954604 ns | 1985497479 ns | 1.11 |
| Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) | 2320413681 ns | 2378938010.5 ns | 0.98 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) | 1888979051.5 ns | 1858802371.5 ns | 1.02 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) | 347336582 ns | 550456586.5 ns | 0.63 |
| Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) | 320737784 ns | 321344700 ns | 1.00 |
| Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) | 321591947.5 ns | 323696598.5 ns | 0.99 |
| Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) | 385026390 ns | 365630240.5 ns | 1.05 |
| Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) | 11903504 ns | 11811649.5 ns | 1.01 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) | 17779926 ns | 17740770.5 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) | 19142173 ns | 19084948 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) | 23833763 ns | 23808418.5 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) | 17808550 ns | 17848657.5 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) | 1153773 ns | 1164698.5 ns | 0.99 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) | 2501924.5 ns | 5670763 ns | 0.44 |
| Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) | 2048820 ns | 2044411 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) | 2027415 ns | 2030189 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) | 2075285 ns | 2070069.5 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) | 197198 ns | 209019 ns | 0.94 |
| Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) | 293927 ns | 293681.5 ns | 1.00 |
| Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) | 267428 ns | 268068 ns | 1.00 |
| Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) | 370820 ns | 370370 ns | 1.00 |
| Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) | 411265.5 ns | 412428 ns | 1.00 |
| Dense(200 => 200)/cpu/reverse/Flux/(200, 128) | 275784 ns | 276023 ns | 1.00 |
| Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) | 408981 ns | 413690 ns | 0.99 |
| Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) | 84137 ns | 83986 ns | 1.00 |
| Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) | 82734 ns | 81692 ns | 1.01 |
| Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) | 84428 ns | 82072 ns | 1.03 |
| Dense(200 => 200)/cpu/forward/Flux/(200, 128) | 87764 ns | 87042 ns | 1.01 |
| Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) | 104554 ns | 104775 ns | 1.00 |
| Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) | 199995225 ns | 186685434.5 ns | 1.07 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) | 326496003.5 ns | 321228142.5 ns | 1.02 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) | 400994732 ns | 392606286 ns | 1.02 |
| Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) | 462902977 ns | 460952246 ns | 1.00 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) | 371403208 ns | 370351810 ns | 1.00 |
| Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) | 335761862 ns | 340754848 ns | 0.99 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) | 51360283.5 ns | 99666130 ns | 0.52 |
| Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) | 44195754 ns | 43812770.5 ns | 1.01 |
| Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) | 43921733.5 ns | 43647694 ns | 1.01 |
| Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) | 50184364.5 ns | 49549909.5 ns | 1.01 |
| Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) | 28887089.5 ns | 28547342 ns | 1.01 |
| Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) | 19203788 ns | 19074425 ns | 1.01 |
| Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) | 19668172.5 ns | 19527795.5 ns | 1.01 |
| Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) | 23871830.5 ns | 23388473 ns | 1.02 |
| Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) | 24346770.5 ns | 24094753 ns | 1.01 |
| Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) | 19745369 ns | 19700860 ns | 1.00 |
| Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) | 6532224 ns | 6506694 ns | 1.00 |
| Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) | 6544657 ns | 6506504 ns | 1.01 |
| Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) | 6497849 ns | 6500898 ns | 1.00 |
| Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) | 6497188 ns | 6496878.5 ns | 1.00 |
This comment was automatically generated by workflow using github-action-benchmark.
We should also cache the parameter gradients, the compiled trace of the loss function, and such, but this should be a good initial version; we need a redesign of the training API later on anyway.
Need to wait for SciMLSensitivity (SciML/SciMLSensitivity.jl#1046) before the doc build goes through.
Structural in what way?
In `ext/LuxEnzymeExt.jl`:
```diff
+using Setfield: @set!
+
+function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data,
+        ts::Lux.Experimental.TrainState) where {F}
+    dps = Enzyme.make_zero(ts.parameters)
+    fwd, rev = Enzyme.autodiff_thunk(
+        Enzyme.ReverseSplitWithPrimal, Enzyme.Const{typeof(objective_function)},
+        Enzyme.Active, Enzyme.Const{typeof(ts.model)},
+        Enzyme.Duplicated{typeof(ts.parameters)},
+        Enzyme.Const{typeof(ts.states)}, Enzyme.Const{typeof(data)})
+    tape, (loss, st_new, stats), shadow_result = fwd(
+        Enzyme.Const(objective_function), Enzyme.Const(ts.model),
+        Enzyme.Duplicated(ts.parameters, dps), Enzyme.Const(ts.states), Enzyme.Const(data))
+    rev(Enzyme.Const(objective_function), Enzyme.Const(ts.model),
+        Enzyme.Duplicated(ts.parameters, dps), Enzyme.Const(ts.states), Enzyme.Const(data),
+        (one(loss), Enzyme.make_zero(st_new), Enzyme.make_zero(stats)), tape)
```
Is there a way to specify a structural zero instead of doing it like `Enzyme.make_zero(st_new)`?
As in, I want to say "don't backpropagate wrt this value". For Zygote I would put a `nothing`.
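For reference, this convention is observable directly from `Zygote.gradient`: an input that does not affect the result gets `nothing` as its gradient, which is Zygote's structural zero. A minimal sketch (not code from this PR):

```julia
using Zygote

# y never enters the computation, so Zygote reports a structural zero for it.
f(x, y) = sum(abs2, x)

gx, gy = Zygote.gradient(f, [1.0, 2.0], [3.0])
# gx == [2.0, 4.0]; gy === nothing
```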
No; activity annotations (e.g., to differentiate or not to differentiate) are presently only at an argument or return level.
So how would I annotate the return type? I am getting a tuple containing a scalar, a named tuple, and an arbitrary object; we don't need to backpropagate through the last two.
Honestly, I would just pass in a function which first calls the original function.
You mean something like:

```julia
function compute_gradients(........)
    st_new_outer = Ref()
    stats_outer = Ref()
    function wrapper_function(args...)
        y, st_new, stats = objective_function(args...)
        st_new_outer[] = st_new
        stats_outer[] = stats
        return y
    end
    .....
end
```
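For concreteness, here is a hedged sketch of how such a wrapper could be wired into a plain `Enzyme.autodiff` call, so that only the scalar loss is an `Active` return while the auxiliary outputs escape through the `Ref`s. The function name `sketch_compute_gradients` and the exact argument order are illustrative assumptions, not the PR's final code:

```julia
using Enzyme

# Illustrative sketch (not the PR implementation): wrap an objective that
# returns (loss, st_new, stats) so Enzyme only differentiates the scalar loss.
function sketch_compute_gradients(objective_function, model, ps, st, data)
    st_new_outer = Ref{Any}()
    stats_outer = Ref{Any}()
    function wrapper_function(model, ps, st, data)
        y, st_new, stats = objective_function(model, ps, st, data)
        st_new_outer[] = st_new    # stash auxiliary outputs; not differentiated
        stats_outer[] = stats
        return y                   # scalar return, marked Active below
    end
    dps = Enzyme.make_zero(ps)
    _, loss = Enzyme.autodiff(
        Enzyme.ReverseWithPrimal, wrapper_function, Enzyme.Active,
        Enzyme.Const(model), Enzyme.Duplicated(ps, dps),
        Enzyme.Const(st), Enzyme.Const(data))
    return dps, loss, stats_outer[], st_new_outer[]
end
```

One caveat: since `wrapper_function` is a closure over the `Ref`s, this is only safe if Enzyme treats those captures as inactive.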
Yeah
Benchmarking `compute_gradients` (adding `BenchmarkTools` for `@btime`):

```julia
using ADTypes, BenchmarkTools, Lux, Random, Enzyme, Optimisers

model = Chain(Conv((3, 3), 3 => 6), GroupNorm(6, 3, gelu), Conv((3, 3), 6 => 32),
    BatchNorm(32, gelu), GlobalMeanPool(), FlattenLayer(), Dense(32, 1))
x = rand(Float32, 32, 32, 3, 4);
tstate = Lux.Experimental.TrainState(Xoshiro(0), model, Adam(0.001f0));

function obj_fn(model, ps, st, x)
    y, st_new = model(x, ps, st)
    return sum(abs2, y), st_new, (;)
end

grads, loss, stats, tstate_new = Lux.Experimental.compute_gradients(
    AutoEnzyme(), obj_fn, x, tstate);
grads, loss, stats, tstate_new = Lux.Experimental.compute_gradients(
    AutoEnzyme(), obj_fn, x, tstate_new);

@btime Lux.Experimental.compute_gradients($AutoEnzyme(), $obj_fn, $x, $tstate);
# 14.726 ms (461 allocations: 9.75 MiB)
@btime Lux.Experimental.compute_gradients($AutoEnzyme(), $obj_fn, $x, $tstate_new);
# 14.233 ms (447 allocations: 9.74 MiB)
```

Caching seems to work correctly.
Ok, I did something wrong; it segfaulted the training test: https://github.com/LuxDL/Lux.jl/actions/runs/9056705562/job/24879628489?pr=640#step:6:739
Locally things pass. Now we need to wait for SciMLSensitivity compats to be updated.
CI is not picking up on the latest SciMLSensitivity.
Codecov Report

Attention: Patch coverage is

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main     #640      +/-  ##
==========================================
- Coverage   87.56%   87.16%    -0.40%
==========================================
  Files          49       50        +1
  Lines        2380     2439       +59
==========================================
+ Hits         2084     2126       +42
- Misses        296      313       +17
```

☔ View full report in Codecov by Sentry.