From 8eecb10af8e5007c2fd64fb68ed5a97256d7123b Mon Sep 17 00:00:00 2001 From: benedict-96 Date: Wed, 15 May 2024 14:39:22 +0200 Subject: [PATCH] Fixed derivative for tensor_transpose_mat_mul. --- src/GeometricMachineLearning.jl | 1 + src/kernels/kernel_ad_routines/tensor_transpose_mat_mul.jl | 6 ++++-- src/kernels/tensor_transpose_mat_mul.jl | 2 +- test/custom_ad_rules/kernel_pullbacks.jl | 3 ++- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/GeometricMachineLearning.jl b/src/GeometricMachineLearning.jl index 42c4bd9f0..92a29f568 100644 --- a/src/GeometricMachineLearning.jl +++ b/src/GeometricMachineLearning.jl @@ -91,6 +91,7 @@ module GeometricMachineLearning include("kernels/kernel_ad_routines/tensor_mat_mul.jl") include("kernels/kernel_ad_routines/mat_tensor_mul.jl") include("kernels/kernel_ad_routines/tensor_tensor_mul.jl") + include("kernels/kernel_ad_routines/tensor_transpose_mat_mul.jl") include("kernels/kernel_ad_routines/tensor_transpose_tensor_mul.jl") include("kernels/kernel_ad_routines/tensor_transpose.jl") include("kernels/kernel_ad_routines/tensor_mat_skew_sym_assign.jl") diff --git a/src/kernels/kernel_ad_routines/tensor_transpose_mat_mul.jl b/src/kernels/kernel_ad_routines/tensor_transpose_mat_mul.jl index 12370c8ce..127ddb7c5 100644 --- a/src/kernels/kernel_ad_routines/tensor_transpose_mat_mul.jl +++ b/src/kernels/kernel_ad_routines/tensor_transpose_mat_mul.jl @@ -2,18 +2,20 @@ This implements the custom pullback for tensor_transpose_mat_mul """ function ChainRulesCore.rrule(::typeof(tensor_transpose_mat_mul), A::AbstractArray{T, 3}, B::AbstractMatrix{T}) where T - @assert axes(A, 2) == axes(B, 1) + @assert axes(A, 1) == axes(B, 1) C = tensor_transpose_mat_mul(A, B) function tensor_transpose_mat_mul_pullback(C_diff) f̄ = NoTangent() #tensor_transpose_mat_mul - A_diff = @thunk tensor_transpose_mat_mul(C_diff, B') + A_diff = @thunk mat_tensor_transpose_mul(B, C_diff) B_diff = @thunk sum(tensor_tensor_mul(A, C_diff), dims=3) return f̄, A_diff, B_diff end return C, tensor_transpose_mat_mul_pullback end +mat_tensor_transpose_mul(B, C) = mat_tensor_mul(B, tensor_transpose(C)) + tensor_transpose_mat_mul(A::Thunk, B::AbstractMatrix) = Thunk(() -> tensor_transpose_mat_mul(unthunk(A), B)) #function tensor_transpose_tensor_mul(A::AbstractArray{T, 3}, B::Thunk) where T diff --git a/src/kernels/tensor_transpose_mat_mul.jl b/src/kernels/tensor_transpose_mat_mul.jl index c3ead80ce..b753ea3c5 100644 --- a/src/kernels/tensor_transpose_mat_mul.jl +++ b/src/kernels/tensor_transpose_mat_mul.jl @@ -15,7 +15,7 @@ This implements the operation (A,B) -> A'*B for a tensor and a matrix end function tensor_transpose_mat_mul!(C, A, B) - @assert size(A)[1] == size(B)[1] + @assert axes(A, 1) == axes(B, 1) backend = KernelAbstractions.get_backend(A) kernel! = tensor_transpose_mat_mul_kernel!(backend) diff --git a/test/custom_ad_rules/kernel_pullbacks.jl b/test/custom_ad_rules/kernel_pullbacks.jl index ae4185a47..23371ffb2 100644 --- a/test/custom_ad_rules/kernel_pullbacks.jl +++ b/test/custom_ad_rules/kernel_pullbacks.jl @@ -1,5 +1,5 @@ using GeometricMachineLearning: lo_mat_mul, up_mat_mul, skew_mat_mul, symmetric_mat_mul, symmetric_mat_right_mul -using GeometricMachineLearning: tensor_mat_mul, mat_tensor_mul, tensor_tensor_mul, tensor_transpose_tensor_mul, assign_q_and_p, tensor_transpose, assign_output_estimate, vec_tensor_mul, tensor_mat_skew_sym_assign +using GeometricMachineLearning: tensor_mat_mul, mat_tensor_mul, tensor_tensor_mul, tensor_transpose_tensor_mul, assign_q_and_p, tensor_transpose, assign_output_estimate, vec_tensor_mul, tensor_mat_skew_sym_assign, tensor_transpose_mat_mul using ChainRulesTestUtils using Printf import Random @@ -14,6 +14,7 @@ function main(first_dim, second_dim, third_dim, third_tensor_dim) test_rrule(assign_q_and_p, rand(first_dim*2, second_dim), first_dim) test_rrule(assign_q_and_p, rand(first_dim*2, second_dim, third_dim), first_dim) test_rrule(tensor_mat_mul, rand(first_dim, second_dim, third_tensor_dim), rand(second_dim, third_dim)) + test_rrule(tensor_transpose_mat_mul, rand(second_dim, first_dim, third_tensor_dim), rand(second_dim, third_dim)) # test_rrule(tensor_transpose_mat_mul, rand(second_dim, first_dim, third_tensor_dim), rand(second_dim, third_dim)) test_rrule(mat_tensor_mul, rand(first_dim, second_dim), rand(second_dim, third_dim, third_tensor_dim)) test_rrule(tensor_tensor_mul, rand(first_dim, second_dim, third_tensor_dim), rand(second_dim, third_dim, third_tensor_dim))