From 4364df8fa25f1a523670881e4ae0a7aefb7f5c1d Mon Sep 17 00:00:00 2001 From: Niklas Heim Date: Tue, 13 Feb 2024 11:42:05 +0100 Subject: [PATCH 1/3] use rrule of KwFunc for Core.kwcall --- src/stage1/generated.jl | 14 ++++--- test/gradcheck.jl | 81 ++++++++++++++++------------------------- 2 files changed, 40 insertions(+), 55 deletions(-) diff --git a/src/stage1/generated.jl b/src/stage1/generated.jl index c1624046..66c34a2c 100644 --- a/src/stage1/generated.jl +++ b/src/stage1/generated.jl @@ -206,20 +206,20 @@ function (::∂⃖{N})(f::Core.IntrinsicFunction, args...) where {N} end # The static parameter on `f` disables the compileable_sig heuristic -function (::∂⃖{N})(f::T, args...) where {T, N} +function (::∂⃖{N})(f::T, args...; kwargs...) where {T, N} if N == 1 # Base case (inlined to avoid ambiguities with manually specified # higher order rules) - z = rrule(DiffractorRuleConfig(), f, args...) + z = rrule(DiffractorRuleConfig(), f, args...; kwargs...) if z === nothing - return ∂⃖recurse{1}()(f, args...) + return ∂⃖recurse{1}()(f, args...; kwargs...) end return z else ∂⃖p = ∂⃖{N-1}() - @destruct z, z̄ = ∂⃖p(rrule, f, args...) + @destruct z, z̄ = ∂⃖p(rrule, f, args...; kwargs...) if z === nothing - return ∂⃖recurse{N}()(f, args...) + return ∂⃖recurse{N}()(f, args...; kwargs...) else return ∂⃖rrule{N}()(z, z̄) end @@ -244,6 +244,10 @@ struct KwFunc{T,S} end (kw::KwFunc)(args...) = kw.kwf(args...) +function ChainRulesCore.rrule(::typeof(Core.kwcall), kwargs, f, args...) + rrule(KwFunc(f), kwargs, f, args...) +end + function ChainRulesCore.rrule(::typeof(Core.kwfunc), f) KwFunc(f), Δ->(NoTangent(), Δ) end diff --git a/test/gradcheck.jl b/test/gradcheck.jl index d003c82d..945722df 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -99,10 +99,11 @@ end @test gradcheck(x -> sum(i->x[i], 1:length(x)), randn(10)) # issue #231 @test gradcheck(x -> sum((i->x[i]).(1:length(x))), randn(10)) @test gradcheck(X -> sum(x -> x^2, X), randn(10)) + @test jacobicheck(x -> sum(x, dims = (2, 3)), (3,4,5)) + @test jacobicheck(x -> sum(abs2, x; dims=1), randn(4, 3, 2)) # MethodError: no method matching copy(::Nothing) - @test_broken jacobicheck(x -> sum(x, dims = (2, 3)), (3,4,5)) - @test_broken jacobicheck(x -> sum(abs2, x; dims=1), randn(4, 3, 2)) + # TODO: interesting that this is the only one that is not fixed @test_broken gradcheck(X -> sum(sum(x -> x^2, X; dims=1)), randn(10)) # issue #681 # Non-differentiable sum of booleans @@ -119,23 +120,15 @@ end @test gradcheck(x -> prod(x), (3,4)) @test gradient(x -> prod(x), (1,2,3))[1] == (6,3,2) - - # MethodError: no method matching copy(::Nothing) - @test_broken jacobicheck(x -> prod(x, dims = (2, 3)), (3,4,5)) + @test jacobicheck(x -> prod(x, dims = (2, 3)), (3,4,5)) end @testset "cumsum" begin @test jacobicheck(x -> cumsum(x), (4,)) - - # TypeError: in typeassert, expected Int64, got a value of type Nothing - @test_broken jacobicheck(x -> cumsum(x, dims=2), (3,4,5)) - @test_broken jacobicheck(x -> cumsum(x, dims=3), (3,4)) # trivial - - # MethodError: no method matching copy(::Nothing) - @test_broken jacobicheck(x -> cumsum(x, dims=1), (3,)) - - # Rewrite reached intrinsic function bitcast. Missing rule? - @test_broken jacobicheck(x -> cumsum(x, dims=3), (5,)) # trivial + @test jacobicheck(x -> cumsum(x, dims=2), (3,4,5)) + @test jacobicheck(x -> cumsum(x, dims=3), (3,4)) # trivial + @test jacobicheck(x -> cumsum(x, dims=1), (3,)) + @test jacobicheck(x -> cumsum(x, dims=3), (5,)) # trivial end @testset "getindex" begin @@ -221,8 +214,7 @@ end @test jacobicheck(x -> reverse(x), rand(17)) @test jacobicheck(x -> reverse(x, 8), rand(17)) @test jacobicheck(x -> reverse(x, 8, 13), rand(17)) - # Rewrite reached intrinsic function bitcast. Missing rule? - @test_broken jacobicheck(x -> reverse(x, dims=2), rand(17, 42)) + @test jacobicheck(x -> reverse(x, dims=2), rand(17, 42)) end @testset "permutedims" begin @@ -237,11 +229,9 @@ end end @testset "repeat" begin - # MethodError: no method matching copy(::Nothing) - @test_broken jacobicheck(x -> repeat(x; inner=2), rand(5)) - @test_broken jacobicheck(x -> repeat(x; inner=2, outer=3), rand(5)) - @test_broken jacobicheck(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3)) - + @test jacobicheck(x -> repeat(x; inner=2), rand(5)) + @test jacobicheck(x -> repeat(x; inner=2, outer=3), rand(5)) + @test jacobicheck(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3)) @test jacobicheck(x -> repeat(x, 3), rand(5)) @test jacobicheck(x -> repeat(x, 2, 3), rand(5)) @test jacobicheck(x -> repeat(x, 5), rand(5,7)) @@ -453,11 +443,10 @@ end @test gradient(v->sort(v)[i], [1.,2,3])[1][correct[2][i]] == 1 end for i = 1:3 - # Rewrite reached intrinsic function bitcast. Missing rule? - @test_broken gradient(v->sort(v,by=x->x%10)[i], [11,2,99])[1][correct[3][i]] == 1 - @test_broken gradient(v->sort(v,by=x->x%10)[i], [2,11,99])[1][correct[4][i]] == 1 - @test_broken gradient(v->sort(v,rev=true)[i], [3.,1,2])[1][correct[5][i]] == 1 - @test_broken gradient(v->sort(v,rev=true)[i], [1.,2,3])[1][correct[6][i]] == 1 + @test gradient(v->sort(v,by=x->x%10)[i], [11,2,99])[1][correct[3][i]] == 1 + @test gradient(v->sort(v,by=x->x%10)[i], [2,11,99])[1][correct[4][i]] == 1 + @test gradient(v->sort(v,rev=true)[i], [3.,1,2])[1][correct[5][i]] == 1 + @test gradient(v->sort(v,rev=true)[i], [1.,2,3])[1][correct[6][i]] == 1 end end @@ -473,27 +462,21 @@ end @testset "maximum" begin @test jacobicheck(maximum, rand(2, 3)) - - # MethodError: no method matching copy(::Nothing) - @test_broken jacobicheck(x -> maximum(x, dims=1), rand(2, 3)) - @test_broken jacobicheck(x -> maximum(x, dims=3), rand(2, 3, 4)) - @test_broken jacobicheck(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4)) - + @test jacobicheck(x -> maximum(x, dims=1), rand(2, 3)) + @test jacobicheck(x -> maximum(x, dims=3), rand(2, 3, 4)) + @test jacobicheck(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4)) @test gradient(x -> 1 / maximum(x), [1., 2, 3])[1] == [0, 0, -1/9] end @testset "minimum" begin @test jacobicheck(minimum, rand(2, 3)) - - # MethodError: no method matching copy(::Nothing) - @test_broken jacobicheck(x -> minimum(x, dims=1), rand(2, 3)) - @test_broken jacobicheck(x -> minimum(x, dims=2), rand(2, 3)) + @test jacobicheck(x -> minimum(x, dims=1), rand(2, 3)) + @test jacobicheck(x -> minimum(x, dims=2), rand(2, 3)) end @testset "dropdims" begin # https://github.com/JuliaDiff/Diffractor.jl/issues/72 - # TypeError: in typeassert, expected Int64, got a value of type Nothing - @test_broken jacobicheck(x -> dropdims(x, dims = 3), rand(2, 2, 1, 2)) - @test_broken jacobicheck(x -> dropdims(x, dims = (2, 3)), rand(2, 1, 1, 3)) + @test jacobicheck(x -> dropdims(x, dims = 3), rand(2, 2, 1, 2)) + @test jacobicheck(x -> dropdims(x, dims = (2, 3)), rand(2, 1, 1, 3)) end @testset "vcat" begin @@ -544,20 +527,19 @@ end end @testset "cat(...; dims = $dim)" for dim in 1:3 - # Rewrite reached intrinsic function bitcast. Missing rule? catdim = (x...) -> cat(x..., dims = dim) - @test_broken jacobicheck(catdim, rand(4,1)) - @test_broken jacobicheck(catdim, rand(5), rand(5,1)) - @test_broken jacobicheck(catdim, rand(2,5), rand(2,5), rand(2,5)) + @test jacobicheck(catdim, rand(4,1)) + @test jacobicheck(catdim, rand(5), rand(5,1)) + @test jacobicheck(catdim, rand(2,5), rand(2,5), rand(2,5)) catdimval = (x...) -> cat(x...; dims = Val(dim)) - @test_broken jacobicheck(catdimval, rand(5), rand(5)) - @test_broken jacobicheck(catdimval, rand(2,5), rand(2,5,1)) + @test jacobicheck(catdimval, rand(5), rand(5)) + @test jacobicheck(catdimval, rand(2,5), rand(2,5,1)) # one empty dim == 1 || continue - @test_broken jacobicheck(catdim, rand(0,5,3), rand(2,5,3)) + @test jacobicheck(catdim, rand(0,5,3), rand(2,5,3)) end @testset "one(s) and zero(s)" begin @@ -586,8 +568,7 @@ end # tests for https://github.com/FluxML/Zygote.jl/issues/724 x1 = rand(3, 3) @test gradient(x -> sum(x .== 0.5), x1) |> only |> isZero - # MethodError: no method matching copy(::Nothing) - @test_broken gradient(x -> sum(x .* (x .== maximum(x, dims=1))), x1)[1] == (x1 .== maximum(x1, dims=1)) + @test gradient(x -> sum(x .* (x .== maximum(x, dims=1))), x1)[1] == (x1 .== maximum(x1, dims=1)) # tests for un-broadcasting *, / via scalar rules @test all(gradient((x,y) -> sum(x .* y), [1,2], 5) .≈ ([5, 5], 3)) @@ -620,7 +601,7 @@ end @test_broken jacobicheck(+, A, B, A) @test jacobicheck(-, A) # in typeassert, expected Int64, got a value of type Nothing - @test_broken jacobicheck(-, A, B) + @test jacobicheck(-, A, B) end end From bf7bbe46ba32490ac57dbaa0dfe0b3490204666c Mon Sep 17 00:00:00 2001 From: Niklas Heim Date: Wed, 28 Feb 2024 15:58:26 +0100 Subject: [PATCH 2/3] split into kw/non-kw versions --- src/stage1/generated.jl | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/src/stage1/generated.jl b/src/stage1/generated.jl index 66c34a2c..9f84bc88 100644 --- a/src/stage1/generated.jl +++ b/src/stage1/generated.jl @@ -206,13 +206,32 @@ function (::∂⃖{N})(f::Core.IntrinsicFunction, args...) where {N} end # The static parameter on `f` disables the compileable_sig heuristic -function (::∂⃖{N})(f::T, args...; kwargs...) where {T, N} +function (::∂⃖{N})(f::T, args...) where {T, N} if N == 1 # Base case (inlined to avoid ambiguities with manually specified # higher order rules) - z = rrule(DiffractorRuleConfig(), f, args...; kwargs...) + z = rrule(DiffractorRuleConfig(), f, args...) if z === nothing - return ∂⃖recurse{1}()(f, args...; kwargs...) + return ∂⃖recurse{1}()(f, args...) + end + return z + else + ∂⃖p = ∂⃖{N-1}() + @destruct z, z̄ = ∂⃖p(rrule, f, args...) + if z === nothing + return ∂⃖recurse{N}()(f, args...) + else + return ∂⃖rrule{N}()(z, z̄) + end + end +end +function (::∂⃖{N})(::typeof(Core.kwcall), kwargs, f::T, args...) where {T, N} + if N == 1 + # Base case (inlined to avoid ambiguities with manually specified + # higher order rules) + z = rrule(DiffractorRuleConfig(), KwFunc(f), kwargs, f, args...) + if z === nothing + return ∂⃖recurse{1}()(f, args..., kwargs...) end return z else From ba66870f11dedd79d1c8949bc6acecbf8dac2c24 Mon Sep 17 00:00:00 2001 From: Niklas Heim Date: Wed, 28 Feb 2024 16:12:16 +0100 Subject: [PATCH 3/3] mark broadcast test as failing --- test/gradcheck.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 945722df..f3bb8e9a 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -556,7 +556,7 @@ end end @testset "broadcast" begin - @test gradient(x -> sum(sin.(x)), Diagonal([0,pi/2,pi]))[1] ≈ [1 0 0; 0 0 0; 0 0 -1] + @test_broken gradient(x -> sum(sin.(x)), Diagonal([0,pi/2,pi]))[1] ≈ [1 0 0; 0 0 0; 0 0 -1] # mixing arrays & Ref(array) a = rand(3)