Skip to content

Commit

Permalink
Fixed derivative for tensor_transpose_mat_mul.
Browse files Browse the repository at this point in the history
  • Loading branch information
benedict-96 committed May 15, 2024
1 parent d6b98b7 commit 8eecb10
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/GeometricMachineLearning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ module GeometricMachineLearning
include("kernels/kernel_ad_routines/tensor_mat_mul.jl")
include("kernels/kernel_ad_routines/mat_tensor_mul.jl")
include("kernels/kernel_ad_routines/tensor_tensor_mul.jl")
include("kernels/kernel_ad_routines/tensor_transpose_mat_mul.jl")
include("kernels/kernel_ad_routines/tensor_transpose_tensor_mul.jl")
include("kernels/kernel_ad_routines/tensor_transpose.jl")
include("kernels/kernel_ad_routines/tensor_mat_skew_sym_assign.jl")
Expand Down
6 changes: 4 additions & 2 deletions src/kernels/kernel_ad_routines/tensor_transpose_mat_mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,20 @@
This implements the custom pullback for tensor_transpose_mat_mul
"""
function ChainRulesCore.rrule(::typeof(tensor_transpose_mat_mul), A::AbstractArray{T, 3}, B::AbstractMatrix{T}) where T
@assert axes(A, 2) == axes(B, 1)
@assert axes(A, 1) == axes(B, 1)
C = tensor_transpose_mat_mul(A, B)
function tensor_transpose_mat_mul_pullback(C_diff)
= NoTangent()
#tensor_transpose_mat_mul
A_diff = @thunk tensor_transpose_mat_mul(C_diff, B')
A_diff = @thunk mat_tensor_transpose_mul(B, C_diff)
B_diff = @thunk sum(tensor_tensor_mul(A, C_diff), dims=3)
return f̄, A_diff, B_diff
end
return C, tensor_transpose_mat_mul_pullback
end

mat_tensor_transpose_mul(B, C) = mat_tensor_mul(B, tensor_transpose(C))

tensor_transpose_mat_mul(A::Thunk, B::AbstractMatrix) = Thunk(() -> tensor_transpose_mat_mul(unthunk(A), B))

Check warning on line 19 in src/kernels/kernel_ad_routines/tensor_transpose_mat_mul.jl

View check run for this annotation

Codecov / codecov/patch

src/kernels/kernel_ad_routines/tensor_transpose_mat_mul.jl#L19

Added line #L19 was not covered by tests

#function tensor_transpose_tensor_mul(A::AbstractArray{T, 3}, B::Thunk) where T
Expand Down
2 changes: 1 addition & 1 deletion src/kernels/tensor_transpose_mat_mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ This implements the operation (A,B) -> A'*B for a tensor and a matrix
end

function tensor_transpose_mat_mul!(C, A, B)
@assert size(A)[1] == size(B)[1]
@assert axes(A, 1) == axes(B, 1)

backend = KernelAbstractions.get_backend(A)
kernel! = tensor_transpose_mat_mul_kernel!(backend)
Expand Down
3 changes: 2 additions & 1 deletion test/custom_ad_rules/kernel_pullbacks.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using GeometricMachineLearning: lo_mat_mul, up_mat_mul, skew_mat_mul, symmetric_mat_mul, symmetric_mat_right_mul
using GeometricMachineLearning: tensor_mat_mul, mat_tensor_mul, tensor_tensor_mul, tensor_transpose_tensor_mul, assign_q_and_p, tensor_transpose, assign_output_estimate, vec_tensor_mul, tensor_mat_skew_sym_assign
using GeometricMachineLearning: tensor_mat_mul, mat_tensor_mul, tensor_tensor_mul, tensor_transpose_tensor_mul, assign_q_and_p, tensor_transpose, assign_output_estimate, vec_tensor_mul, tensor_mat_skew_sym_assign, tensor_transpose_mat_mul
using ChainRulesTestUtils
using Printf
import Random
Expand All @@ -14,6 +14,7 @@ function main(first_dim, second_dim, third_dim, third_tensor_dim)
test_rrule(assign_q_and_p, rand(first_dim*2, second_dim), first_dim)
test_rrule(assign_q_and_p, rand(first_dim*2, second_dim, third_dim), first_dim)
test_rrule(tensor_mat_mul, rand(first_dim, second_dim, third_tensor_dim), rand(second_dim, third_dim))
test_rrule(tensor_transpose_mat_mul, rand(second_dim, first_dim, third_tensor_dim), rand(second_dim, third_dim))
# test_rrule(tensor_transpose_mat_mul, rand(second_dim, first_dim, third_tensor_dim), rand(second_dim, third_dim))
test_rrule(mat_tensor_mul, rand(first_dim, second_dim), rand(second_dim, third_dim, third_tensor_dim))
test_rrule(tensor_tensor_mul, rand(first_dim, second_dim, third_tensor_dim), rand(second_dim, third_dim, third_tensor_dim))
Expand Down

0 comments on commit 8eecb10

Please sign in to comment.