Add EnzymeRule for conv!/gather!/scatter!/dropout!/pool! #536
Conversation
I wasn't sure where to put the test, but here's an example:

```julia
using Enzyme
using NNlib

w = randn(Float32, 3, 3, 5, 7);
dw = zero(w)
x = randn(Float32, (3, 3, 5, 8));

loss(w, x) = sum(conv(x, w))

Enzyme.autodiff(Reverse, loss, Duplicated(w, dw), Const(x));
@show dw
```
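As a quick sanity check on rules like this, an AD-computed gradient such as `dw` can be compared against a central-difference approximation. The helper below is a pure-Julia sketch: the `fdgrad` name and the toy quadratic loss are illustrative stand-ins, not NNlib code; in practice one would pass the `w -> sum(conv(x, w))` closure from the example above.

```julia
# Hedged sketch: generic central-difference gradient, usable to cross-check
# an AD-computed gradient such as `dw` above. Pure Julia, no packages.
function fdgrad(f, w; eps=1e-4)
    g = similar(w, Float64)
    for i in eachindex(w)
        wp = copy(w); wp[i] += eps
        wm = copy(w); wm[i] -= eps
        g[i] = (f(wp) - f(wm)) / (2eps)   # central difference, O(eps^2) error
    end
    return g
end

# Toy stand-in for `w -> sum(conv(x, w))`; the gradient of sum(abs2, w) is 2w.
w = [1.0, 2.0, 3.0]
toyloss(w) = sum(abs2, w)
g = fdgrad(toyloss, w)
```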
Another note: right now there is almost certainly more copying of inputs than necessary. In the future we should add some sort of (type?) annotation to denote that some memory is immutable, or perhaps have a copy-on-write wrapper (as proposed by @gaurav-arya).
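As a sketch of what such an immutability annotation might look like (the names here are hypothetical, not an NNlib or Enzyme API): a thin read-only array wrapper whose `setindex!` errors, which a rule could treat as safe to alias rather than copy defensively.

```julia
# Hypothetical sketch of the "immutable memory" annotation idea: a read-only
# array wrapper. A rule seeing this type could skip the defensive copy.
struct ReadOnlyArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
    parent::A
end

Base.size(a::ReadOnlyArray) = size(a.parent)
Base.getindex(a::ReadOnlyArray, i...) = getindex(a.parent, i...)
Base.setindex!(a::ReadOnlyArray, v, i...) =
    error("ReadOnlyArray is read-only; copy it before mutating")

x = ReadOnlyArray([1.0, 2.0, 3.0])
s = sum(x)   # reads work through the AbstractArray interface
```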
On Sep 24, 2023, Carlo Lucibello commented on this pull request, in Project.toml:

> ```diff
> @@ -16,13 +17,15 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
>  [weakdeps]
>  AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
> +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
> ```
>
> We could depend just on EnzymeCore right?

Sure.
Force-pushed from f8d4717 to f569b98.
Even if EnzymeCore is a lean dependency, I think it is better to put all of this under an NNlibEnzymeCoreExt extension.
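For reference, the extension layout being suggested would look roughly like the sketch below in Project.toml. This assumes an extension module named NNlibEnzymeCoreExt; the UUID is a placeholder standing in for EnzymeCore's real UUID, which is not quoted in this thread.

```toml
# Sketch of the weakdeps/extensions layout under discussion.
# "..." is a placeholder, not EnzymeCore's actual UUID.
[weakdeps]
EnzymeCore = "..."

[extensions]
NNlibEnzymeCoreExt = "EnzymeCore"
```

With this layout, the code in `ext/NNlibEnzymeCoreExt.jl` only loads once a user also loads EnzymeCore, so NNlib itself takes no hard dependency on it.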
So I tried that originally, but the problem is that your CUDA extension would then itself need an EnzymeCore extension. If it matters, this also mirrors the way ChainRulesCore is used presently.
This can be addressed by also having an NNlibEnzymeCoreCUDA extension.
True, but ChainRulesCore is AD-framework-agnostic and quite widespread; EnzymeCore is not. It is a very lean dependency, so if making it an extension proves too tricky I wouldn't oppose making it a hard one, but I think we should avoid that if we can.
@CarloLucibello my attempt to use extensions (the current code) fails. However, the previous commit (f430a9b), without the extension, succeeds. If you see how to fix the extension, we can use that; if not, we can use the prior commit.
In src/conv.jl (outdated):

```diff
@@ -47,7 +47,7 @@
 Apply convolution filter `w` to input `x`. `x` and `w` are 3d/4d/5d tensors
 in 1d/2d/3d convolutions respectively. `x` and `w` may have real or complex element types.
 """
-function conv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1, flipped = false, groups = 1) where {T, N}
+@inline function conv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1, flipped = false, groups = 1) where {T, N}
```
Maybe it's better to factor this change out into another PR.
Maybe this is a bit early, but I'd love to see benchmarks of this vs Zygote.
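A minimal harness for such a comparison could look like the sketch below. It is pure Julia; the per-backend gradient closures are assumed (e.g. `w -> gradient(loss, w)[1]` for Zygote, or a closure around the `Enzyme.autodiff` call from the example above), and BenchmarkTools would give more rigorous numbers than `@elapsed`.

```julia
# Hedged sketch: time any backend's gradient closure; pass one closure per
# AD backend (Enzyme, Zygote, ...) and compare the returned minimum times.
function best_time(f, arg; reps=5)
    f(arg)                                    # warm-up run (triggers compilation)
    minimum(@elapsed(f(arg)) for _ in 1:reps) # best of `reps` timed runs
end

# Illustrative stand-in; real use would pass the Enzyme and Zygote closures.
toygrad(w) = 2 .* w
t = best_time(toygrad, rand(100))
```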
This was a tour de force of a PR, thanks @wsmoses!
Trying this out, hopefully to serve as a template for other functions.