Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use rrule of KwFunc for Core.kwcall #270

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/stage1/generated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,25 @@ function (::∂⃖{N})(f::T, args...) where {T, N}
end
end
end
function (::∂⃖{N})(::typeof(Core.kwcall), kwargs, f::T, args...) where {T, N}
if N == 1
# Base case (inlined to avoid ambiguities with manually specified
# higher order rules)
z = rrule(DiffractorRuleConfig(), KwFunc(f), kwargs, f, args...)
Comment on lines +228 to +232
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

its basically a copy of the non-kw version of the function, but that is what we want in order to avoid ADing through the kw machinery, if there are no kws if I understand correctly?

if z === nothing
return ∂⃖recurse{1}()(f, args..., kwargs...)
end
return z
else
∂⃖p = ∂⃖{N-1}()
@destruct z, z̄ = ∂⃖p(rrule, f, args...; kwargs...)
if z === nothing
return ∂⃖recurse{N}()(f, args...; kwargs...)
else
return ∂⃖rrule{N}()(z, z̄)
end
end
end

function ChainRulesCore.rrule_via_ad(::DiffractorRuleConfig, f::T, args...) where {T}
Tuple{Any, Any}(∂⃖{1}()(f, args...))
Expand All @@ -244,6 +263,10 @@ struct KwFunc{T,S}
end
(kw::KwFunc)(args...) = kw.kwf(args...)

function ChainRulesCore.rrule(::typeof(Core.kwcall), kwargs, f, args...)
rrule(KwFunc(f), kwargs, f, args...)
Copy link
Member

@oxinabox oxinabox Feb 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why isn't this

Suggested change
rrule(KwFunc(f), kwargs, f, args...)
rrule(f, args...; kwargs...)

is that the same, or is it different?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

@nmheim nmheim Feb 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not exactly sure why the KwFunc struct is needed though.. it seems like could be done via rrule(::typeof(Core.kwcall), kwargs, f, args...) directly?

Copy link
Contributor Author

@nmheim nmheim Feb 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removing the KwFunc and dispatching on kwcall directly seems to work, but I was afraid to remove something which I don't exactly understand

function ChainRulesCore.rrule(::typeof(Core.kwcall), kwargs, f, args...)
    r = Core.kwfunc(rrule)(kwargs, rrule, f, args...)
    if r === nothing
        return nothing
    end
    x, back = r
    x, Δ->begin
        (NoTangent(), NoTangent(), back(Δ)...)
    end
end

Copy link
Member

@oxinabox oxinabox Feb 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh this might be the thing that is there to avoid ADing through so much of the kwarg machinery in the nested AD case

end

function ChainRulesCore.rrule(::typeof(Core.kwfunc), f)
KwFunc(f), Δ->(NoTangent(), Δ)
end
Expand Down
83 changes: 32 additions & 51 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,11 @@ end
@test gradcheck(x -> sum(i->x[i], 1:length(x)), randn(10)) # issue #231
@test gradcheck(x -> sum((i->x[i]).(1:length(x))), randn(10))
@test gradcheck(X -> sum(x -> x^2, X), randn(10))
@test jacobicheck(x -> sum(x, dims = (2, 3)), (3,4,5))
@test jacobicheck(x -> sum(abs2, x; dims=1), randn(4, 3, 2))

# MethodError: no method matching copy(::Nothing)
@test_broken jacobicheck(x -> sum(x, dims = (2, 3)), (3,4,5))
@test_broken jacobicheck(x -> sum(abs2, x; dims=1), randn(4, 3, 2))
# TODO: interesting that this is the only one that is not fixed
@test_broken gradcheck(X -> sum(sum(x -> x^2, X; dims=1)), randn(10)) # issue #681

# Non-differentiable sum of booleans
Expand All @@ -119,23 +120,15 @@ end

@test gradcheck(x -> prod(x), (3,4))
@test gradient(x -> prod(x), (1,2,3))[1] == (6,3,2)

# MethodError: no method matching copy(::Nothing)
@test_broken jacobicheck(x -> prod(x, dims = (2, 3)), (3,4,5))
@test jacobicheck(x -> prod(x, dims = (2, 3)), (3,4,5))
end

@testset "cumsum" begin
@test jacobicheck(x -> cumsum(x), (4,))

# TypeError: in typeassert, expected Int64, got a value of type Nothing
@test_broken jacobicheck(x -> cumsum(x, dims=2), (3,4,5))
@test_broken jacobicheck(x -> cumsum(x, dims=3), (3,4)) # trivial

# MethodError: no method matching copy(::Nothing)
@test_broken jacobicheck(x -> cumsum(x, dims=1), (3,))

# Rewrite reached intrinsic function bitcast. Missing rule?
@test_broken jacobicheck(x -> cumsum(x, dims=3), (5,)) # trivial
@test jacobicheck(x -> cumsum(x, dims=2), (3,4,5))
@test jacobicheck(x -> cumsum(x, dims=3), (3,4)) # trivial
@test jacobicheck(x -> cumsum(x, dims=1), (3,))
@test jacobicheck(x -> cumsum(x, dims=3), (5,)) # trivial
end

@testset "getindex" begin
Expand Down Expand Up @@ -221,8 +214,7 @@ end
@test jacobicheck(x -> reverse(x), rand(17))
@test jacobicheck(x -> reverse(x, 8), rand(17))
@test jacobicheck(x -> reverse(x, 8, 13), rand(17))
# Rewrite reached intrinsic function bitcast. Missing rule?
@test_broken jacobicheck(x -> reverse(x, dims=2), rand(17, 42))
@test jacobicheck(x -> reverse(x, dims=2), rand(17, 42))
end

@testset "permutedims" begin
Expand All @@ -237,11 +229,9 @@ end
end

@testset "repeat" begin
# MethodError: no method matching copy(::Nothing)
@test_broken jacobicheck(x -> repeat(x; inner=2), rand(5))
@test_broken jacobicheck(x -> repeat(x; inner=2, outer=3), rand(5))
@test_broken jacobicheck(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3))

@test jacobicheck(x -> repeat(x; inner=2), rand(5))
@test jacobicheck(x -> repeat(x; inner=2, outer=3), rand(5))
@test jacobicheck(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3))
@test jacobicheck(x -> repeat(x, 3), rand(5))
@test jacobicheck(x -> repeat(x, 2, 3), rand(5))
@test jacobicheck(x -> repeat(x, 5), rand(5,7))
Expand Down Expand Up @@ -453,11 +443,10 @@ end
@test gradient(v->sort(v)[i], [1.,2,3])[1][correct[2][i]] == 1
end
for i = 1:3
# Rewrite reached intrinsic function bitcast. Missing rule?
@test_broken gradient(v->sort(v,by=x->x%10)[i], [11,2,99])[1][correct[3][i]] == 1
@test_broken gradient(v->sort(v,by=x->x%10)[i], [2,11,99])[1][correct[4][i]] == 1
@test_broken gradient(v->sort(v,rev=true)[i], [3.,1,2])[1][correct[5][i]] == 1
@test_broken gradient(v->sort(v,rev=true)[i], [1.,2,3])[1][correct[6][i]] == 1
@test gradient(v->sort(v,by=x->x%10)[i], [11,2,99])[1][correct[3][i]] == 1
@test gradient(v->sort(v,by=x->x%10)[i], [2,11,99])[1][correct[4][i]] == 1
@test gradient(v->sort(v,rev=true)[i], [3.,1,2])[1][correct[5][i]] == 1
@test gradient(v->sort(v,rev=true)[i], [1.,2,3])[1][correct[6][i]] == 1
end
end

Expand All @@ -473,27 +462,21 @@ end

@testset "maximum" begin
@test jacobicheck(maximum, rand(2, 3))

# MethodError: no method matching copy(::Nothing)
@test_broken jacobicheck(x -> maximum(x, dims=1), rand(2, 3))
@test_broken jacobicheck(x -> maximum(x, dims=3), rand(2, 3, 4))
@test_broken jacobicheck(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4))

@test jacobicheck(x -> maximum(x, dims=1), rand(2, 3))
@test jacobicheck(x -> maximum(x, dims=3), rand(2, 3, 4))
@test jacobicheck(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4))
@test gradient(x -> 1 / maximum(x), [1., 2, 3])[1] == [0, 0, -1/9]
end

@testset "minimum" begin
@test jacobicheck(minimum, rand(2, 3))

# MethodError: no method matching copy(::Nothing)
@test_broken jacobicheck(x -> minimum(x, dims=1), rand(2, 3))
@test_broken jacobicheck(x -> minimum(x, dims=2), rand(2, 3))
@test jacobicheck(x -> minimum(x, dims=1), rand(2, 3))
@test jacobicheck(x -> minimum(x, dims=2), rand(2, 3))
end

@testset "dropdims" begin # https://github.com/JuliaDiff/Diffractor.jl/issues/72
# TypeError: in typeassert, expected Int64, got a value of type Nothing
@test_broken jacobicheck(x -> dropdims(x, dims = 3), rand(2, 2, 1, 2))
@test_broken jacobicheck(x -> dropdims(x, dims = (2, 3)), rand(2, 1, 1, 3))
@test jacobicheck(x -> dropdims(x, dims = 3), rand(2, 2, 1, 2))
@test jacobicheck(x -> dropdims(x, dims = (2, 3)), rand(2, 1, 1, 3))
end

@testset "vcat" begin
Expand Down Expand Up @@ -544,20 +527,19 @@ end
end

@testset "cat(...; dims = $dim)" for dim in 1:3
# Rewrite reached intrinsic function bitcast. Missing rule?

catdim = (x...) -> cat(x..., dims = dim)
@test_broken jacobicheck(catdim, rand(4,1))
@test_broken jacobicheck(catdim, rand(5), rand(5,1))
@test_broken jacobicheck(catdim, rand(2,5), rand(2,5), rand(2,5))
@test jacobicheck(catdim, rand(4,1))
@test jacobicheck(catdim, rand(5), rand(5,1))
@test jacobicheck(catdim, rand(2,5), rand(2,5), rand(2,5))

catdimval = (x...) -> cat(x...; dims = Val(dim))
@test_broken jacobicheck(catdimval, rand(5), rand(5))
@test_broken jacobicheck(catdimval, rand(2,5), rand(2,5,1))
@test jacobicheck(catdimval, rand(5), rand(5))
@test jacobicheck(catdimval, rand(2,5), rand(2,5,1))

# one empty
dim == 1 || continue
@test_broken jacobicheck(catdim, rand(0,5,3), rand(2,5,3))
@test jacobicheck(catdim, rand(0,5,3), rand(2,5,3))
end

@testset "one(s) and zero(s)" begin
Expand All @@ -574,7 +556,7 @@ end
end

@testset "broadcast" begin
@test gradient(x -> sum(sin.(x)), Diagonal([0,pi/2,pi]))[1] ≈ [1 0 0; 0 0 0; 0 0 -1]
@test_broken gradient(x -> sum(sin.(x)), Diagonal([0,pi/2,pi]))[1] ≈ [1 0 0; 0 0 0; 0 0 -1]

# mixing arrays & Ref(array)
a = rand(3)
Expand All @@ -586,8 +568,7 @@ end
# tests for https://github.com/FluxML/Zygote.jl/issues/724
x1 = rand(3, 3)
@test gradient(x -> sum(x .== 0.5), x1) |> only |> isZero
# MethodError: no method matching copy(::Nothing)
@test_broken gradient(x -> sum(x .* (x .== maximum(x, dims=1))), x1)[1] == (x1 .== maximum(x1, dims=1))
@test gradient(x -> sum(x .* (x .== maximum(x, dims=1))), x1)[1] == (x1 .== maximum(x1, dims=1))

# tests for un-broadcasting *, / via scalar rules
@test all(gradient((x,y) -> sum(x .* y), [1,2], 5) .≈ ([5, 5], 3))
Expand Down Expand Up @@ -620,7 +601,7 @@ end
@test_broken jacobicheck(+, A, B, A)
@test jacobicheck(-, A)
# in typeassert, expected Int64, got a value of type Nothing
@test_broken jacobicheck(-, A, B)
@test jacobicheck(-, A, B)
end

end
Loading