Add EnzymeRule for conv!/gather!/scatter!/dropout!/pool! #536

Merged: 26 commits, Sep 28, 2023


wsmoses (Contributor) commented Sep 24, 2023

Trying this out, hopefully to serve as a template for other functions.

wsmoses (Contributor, Author) commented Sep 24, 2023

I wasn't sure where to put the test but here's an example:

using Enzyme
using NNlib

w = randn(Float32, 3, 3, 5, 7);
dw = zero(w)
loss(w, x) = sum(conv(x, w))
x = randn(Float32, (3, 3, 5, 8));

Enzyme.autodiff(Reverse, loss, Duplicated(w, dw), Const(x));

@show dw
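
One way to sanity-check the result, sketched here under the assumption that NNlib's `DenseConvDims` and `∇conv_filter` are available in the installed version, is to compare the Enzyme gradient against NNlib's own filter-gradient kernel. Since the loss is a plain sum, the upstream gradient `dy` is an array of ones with the shape of the convolution output. The names `cdims`, `dy`, and `dw_ref` and the tolerance are illustrative choices, not part of the PR.

```julia
using Enzyme
using NNlib

w = randn(Float32, 3, 3, 5, 7)
x = randn(Float32, 3, 3, 5, 8)
dw = zero(w)                      # shadow array; Enzyme accumulates into it
loss(w, x) = sum(conv(x, w))

Enzyme.autodiff(Reverse, loss, Duplicated(w, dw), Const(x))

# Reference gradient from NNlib's filter-gradient kernel (assumed API).
# d(sum(y))/dy is all ones, so dy is ones shaped like the conv output.
cdims = DenseConvDims(x, w)
dy = ones(Float32, size(conv(x, w, cdims)))
dw_ref = NNlib.∇conv_filter(x, dy, cdims)

# Tolerance is a guess suited to Float32 accumulation.
@show isapprox(dw, dw_ref; rtol = 1f-4)
```

Note that `dw` must be zeroed before each call to `autodiff`, because reverse-mode results are added to the shadow rather than overwriting it.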

@wsmoses wsmoses marked this pull request as ready for review September 24, 2023 07:53
wsmoses (Contributor, Author) commented Sep 24, 2023

Another note: right now there is almost certainly more copying of inputs than necessary.

In the future we should add some sort of (type?) annotation to denote that some memory is immutable, or perhaps have a copy-on-write wrapper (as proposed by @gaurav-arya).

@wsmoses wsmoses mentioned this pull request Sep 24, 2023
@wsmoses wsmoses changed the title Add EnzymeRule for conv! Add EnzymeRule for conv!/gather!/scatter!/dropout!/pool! Sep 25, 2023
@wsmoses wsmoses force-pushed the enzymeconv branch 3 times, most recently from f8d4717 to f569b98 Compare September 25, 2023 06:45
CarloLucibello (Member) commented

Even if EnzymeCore is a lean dependency, I think it is better to put everything under an NNlibEnzymeCoreExt extension.

wsmoses (Contributor, Author) commented Sep 26, 2023

So I tried that originally, but the problem is that your CUDA extension would then itself need an EnzymeCore extension.

If it matters, it also mirrors the way ChainRulesCore is used presently.

CarloLucibello (Member) commented

> So I tried that originally, but the problem with that is then your cuda extension would itself need an enzymecore extension.

This can be addressed by also having a NNlibEnzymeCoreCUDA extension.

> If it matters, it also mirrors the way chainrulescore is used presently.

True, but ChainRulesCore is AD-framework agnostic and quite widespread; EnzymeCore is not. It is a very lean dependency, so if making it an extension proves too tricky I wouldn't oppose making it a hard dependency, but I think we should avoid that if we can.
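
For readers unfamiliar with Julia package extensions, the layout being discussed can be sketched roughly as follows. This is a minimal skeleton, not the code from this PR: the Project.toml entries are shown as comments with the UUID left elided, and the CUDA-triggered entry is only a guess at how the second extension mentioned above might be declared.

```julia
# Project.toml entries (shown as comments; UUID intentionally elided):
#
#   [weakdeps]
#   EnzymeCore = "<EnzymeCore UUID>"
#
#   [extensions]
#   NNlibEnzymeCoreExt = "EnzymeCore"
#   # Hypothetical: an extension loaded only when both packages are present.
#   # NNlibEnzymeCoreCUDAExt = ["EnzymeCore", "CUDA"]

# ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl
module NNlibEnzymeCoreExt

using NNlib
import EnzymeCore

# Enzyme rules for conv!, gather!, scatter!, dropout!, and the pooling
# functions would be defined here, so they are compiled only when the
# user has loaded both NNlib and EnzymeCore.

end # module
```

With this layout EnzymeCore stays a weak dependency: users who never load it pay no load-time cost, at the price of needing a separate extension for every combination of optional packages.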

wsmoses (Contributor, Author) commented Sep 26, 2023

@CarloLucibello my attempt to use extensions (current code) fails. However, without the extension the previous commit (f430a9b) succeeds.

If you see how to fix the extension, we can use that, if not we can use the prior commit.

src/conv.jl
@@ -47,7 +47,7 @@
 Apply convolution filter `w` to input `x`. `x` and `w` are 3d/4d/5d tensors
 in 1d/2d/3d convolutions respectively. `x` and `w` may have real or complex element types.
 """
-function conv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1, flipped = false, groups = 1) where {T, N}
+@inline function conv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1, flipped = false, groups = 1) where {T, N}
Member

Maybe better to factor this change out into another PR.

maxfreu (Contributor) commented Sep 27, 2023

Maybe this is a bit early, but I'd love to see benchmarks of this vs Zygote.

ToucheSir (Member) left a comment

This was a tour de force of a PR, thanks @wsmoses!

@ToucheSir ToucheSir merged commit aea063c into FluxML:master Sep 28, 2023
10 of 12 checks passed
@wsmoses wsmoses deleted the enzymeconv branch September 28, 2023 21:44