Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented Tensor-Matrix multiplication for SymmetricMatrix #148

Merged
merged 14 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions src/arrays/symmetric.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
@doc raw"""
A `SymmetricMatrix` ``A`` is a matrix ``A^T = A``.

If the constructor is called with a matrix as input it returns a symmetric matrix via the projection:
```math
A \mapsto \frac{1}{2}(A + A^T).
```
This is a projection defined via the canonical metric $(A,B) \mapsto \mathrm{tr}(A^TB)$.

Internally the `struct` saves a vector $S$ of size $n(n+1)\div2$. The conversion is done the following way:
Expand All @@ -14,10 +10,18 @@ Internally the `struct` saves a vector $S$ of size $n(n+1)\div2$. The conversion
```

So ``S`` stores a string of vectors taken from $A$: $S = [\tilde{a}_1, \tilde{a}_2, \ldots, \tilde{a}_n]$ with $\tilde{a}_i = [[A]_{i1},[A]_{i2},\ldots,[A]_{ii}]$.

### Constructor
If the constructor is called with a matrix as input it returns a symmetric matrix via the projection:
```math
A \mapsto \frac{1}{2}(A + A^T).
```

It can also be called with two arguments `S::AbstractVector` and `n::Integer` where `length(S) == n * (n + 1) ÷ 2` has to be true.
"""
mutable struct SymmetricMatrix{T, AT <: AbstractVector{T}} <: AbstractMatrix{T}
S::AT
n::Integer
n::Int

function SymmetricMatrix(S::AbstractVector, n::Integer)
@assert length(S) == n*(n+1)÷2
Expand Down Expand Up @@ -66,9 +70,10 @@ end

function Base.getindex(A::SymmetricMatrix,i::Int,j::Int)
if i ≥ j
return A.S[((i-1)*i)÷2+j]
A.S[((i-1)*i)÷2+j]
else
A.S[(j-1)*j÷2+i]
end
return A.S[(j-1)*j÷2+i]
end

Base.parent(A::SymmetricMatrix) = A.S
Expand Down
5 changes: 2 additions & 3 deletions src/kernels/kernel_ad_routines/mat_tensor_mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,8 @@ end
for i in axes(dC, 1)
sum_i = (i - 1) * i ÷ 2
if sum_i < l
for j in axes(dC, 1)
if l < (sum_i + i)
for j in axes(dC, 2)
if 1 ≤ (l - sum_i) < i
temp += A[l - sum_i, j, h] * dC[i, j, h]
temp += A[i, j, h] * dC[l - sum_i, j, h]
end
Expand Down Expand Up @@ -256,7 +256,6 @@ function ChainRulesCore.rrule(::typeof(mat_tensor_mul), B::SymmetricMatrix{T}, A
return C, symmetric_mul_pullback
end


############### Thunks

mat_tensor_mul(B::AbstractMatrix, A::Thunk) = Thunk(() -> mat_tensor_mul(B, unthunk(A)))
Expand Down
70 changes: 70 additions & 0 deletions src/kernels/kernel_ad_routines/tensor_mat_mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,73 @@
function tensor_transpose_tensor_mul(A::AbstractArray{T, 3}, B::Thunk) where T
Thunk(() -> tensor_transpose_tensor_mul(A, unthunk(B)))
end

############### Symmetric (right mul)

# Kernel that fills `dA` for the reverse rule of `symmetric_mat_right_mul`.
#
# Each work item `(l, m, h)` writes
#     dA[l, m, h] = Σ_j s(m, j) * dC[l, j, h],
# where `s(m, j)` is read from the packed vector `S`:
#   * for j ≤ m at position (m - 1) * m ÷ 2 + j,
#   * for j > m at position (j - 1) * j ÷ 2 + m.
# The upper bound of the sum is `size(dA, 2)`.
@kernel function symmetric_right_da_kernel!(dA::AT, S::AbstractVector{T}, dC::AT) where {T, AT <: AbstractArray{T, 3}}
l, m, h = @index(Global, NTuple)

Check warning on line 28 in src/kernels/kernel_ad_routines/tensor_mat_mul.jl

View check run for this annotation

Codecov / codecov/patch

src/kernels/kernel_ad_routines/tensor_mat_mul.jl#L27-L28

Added lines #L27 - L28 were not covered by tests

temp = zero(T)

# entries with column index j ≤ m: read from the block of `S` that belongs to row m
for j = 1:m
temp += S[(m - 1) * m ÷ 2 + j] * dC[l, j, h]
end
# entries with j > m: read the mirrored position (row j, column m) of the packed storage
for j = (m+1):size(dA, 2)
temp += S[(j - 1) * j ÷ 2 + m] * dC[l, j, h]
end

dA[l, m, h] = temp

nothing
end

# Kernel that fills `dS` for the reverse rule of `symmetric_mat_right_mul`.
#
# Each work item `(l, h)` accumulates one entry `dS[l, h]`; `l` indexes the packed
# storage vector and `h` the third axis of `B`/`dC`. The calling rrule sums `dS` over
# its second dimension afterwards.
@kernel function symmetric_right_ds_kernel!(dS::AbstractMatrix{T}, B::AT, dC::AT) where {T, AT <: AbstractArray{T, 3}}
l, h = @index(Global, NTuple)

Check warning on line 45 in src/kernels/kernel_ad_routines/tensor_mat_mul.jl

View check run for this annotation

Codecov / codecov/patch

src/kernels/kernel_ad_routines/tensor_mat_mul.jl#L44-L45

Added lines #L44 - L45 were not covered by tests
temp = zero(T)

for i in axes(dC, 1)
# contributes only for the k with 1 ≤ l - sum_k ≤ k, i.e. when `l` lies in the block
# of the packed vector that starts after offset (k - 1) * k ÷ 2 (diagonal included)
for k in axes(dC, 2)
sum_k = (k - 1) * k ÷ 2
temp += 1 ≤ l - sum_k ≤ k ? B[i, k, h] * dC[i, l - sum_k, h] : zero(T)
end
# mirrored contribution with the roles of the two indices swapped; the strict `< j`
# leaves out the position l - sum_j == j that the loop above already handled
for j in axes(dC, 2)
sum_j = (j - 1) * j ÷ 2
temp += 1 ≤ l - sum_j < j ? B[i, l - sum_j, h] * dC[i, j, h] : zero(T)
end
end

dS[l, h] = temp

nothing
end

# Reverse rule for `symmetric_mat_right_mul(A, S, n)`.
#
# Returns the primal result together with a pullback. For a cotangent `dC` the pullback
# returns `(NoTangent(), dA, dS, NoTangent())`:
#   * `dA` has the shape of `A` and is filled by `symmetric_right_da_kernel!`,
#   * `dS` has length `length(S)`; it is computed per slice of the third axis by
#     `symmetric_right_ds_kernel!` and then summed over that axis,
#   * the function itself and the integer `n` carry no tangent.
function ChainRulesCore.rrule(::typeof(symmetric_mat_right_mul), A::AbstractArray{T, 3}, S::AbstractVector{T}, n::Int) where T
    primal = symmetric_mat_right_mul(A, S, n)

    function symmetric_mat_mul_pullback(dC::AbstractArray{T, 3})
        device = KernelAbstractions.get_backend(dC)
        num_entries = length(S)

        # cotangent of the tensor argument
        dA = zero(A)
        # one column of the packed cotangent per slice along the third axis
        dS_per_slice = KernelAbstractions.zeros(device, T, num_entries, size(dC, 3))

        symmetric_right_da_kernel!(device)(dA, S, dC, ndrange = size(dA))
        symmetric_right_ds_kernel!(device)(dS_per_slice, A, dC, ndrange = size(dS_per_slice))

        dS = reshape(sum(dS_per_slice, dims = 2), num_entries)

        return NoTangent(), dA, dS, NoTangent()
    end

    return primal, symmetric_mat_mul_pullback
end

# Reverse rule for `tensor_mat_mul(A, B)` where `B` is a `SymmetricMatrix`.
#
# The pullback delegates to the rule for `symmetric_mat_right_mul` on the packed storage
# `B.S` and wraps the resulting vector cotangent in a `SymmetricMatrix` of size `B.n`.
# It returns `(f̄, dA, dB)`.
#
# NOTE: the coverage annotations attached to this rule reported it as not covered by
# tests; the returned closure name previously did not match the closure defined here.
function ChainRulesCore.rrule(::typeof(tensor_mat_mul), A::AbstractArray{T, 3}, B::SymmetricMatrix{T}) where T
    @assert size(A, 2) == B.n
    C = tensor_mat_mul(A, B)

    function symmetric_right_mul_pullback(dC::AbstractArray{T, 3})
        # the last entry is the (non-existent) tangent for the integer argument `n`
        f̄, dA, dS, _ = rrule(symmetric_mat_right_mul, A, B.S, B.n)[2](dC)

        return f̄, dA, SymmetricMatrix(dS, B.n)
    end

    # return the closure defined above (the original referenced `symmetric_mul_pullback`,
    # which is not defined in this scope)
    return C, symmetric_right_mul_pullback
end
40 changes: 40 additions & 0 deletions src/kernels/tensor_mat_mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,44 @@
C = KernelAbstractions.zeros(backend, T, tensor_shape...)
tensor_mat_mul!(C, A, B)
C
end

########################### SymmetricMatrix (right multiplication)

# Kernel for multiplying a 3-tensor from the right with a symmetric matrix that is given
# through its packed storage vector `S`.
#
# Each work item `(i, j, l)` writes
#     C[i, j, l] = Σ_k B[i, k, l] * s(k, j),
# where `s(k, j)` is read from `S`:
#   * for k ≥ j (k up to `n`) at position (k - 1) * k ÷ 2 + j,
#   * for k < j at position (j - 1) * j ÷ 2 + k.
@kernel function symmetric_mat_right_mul_kernel!(C::AbstractArray{T, 3}, B::AbstractArray{T, 3}, S::AbstractVector{T}, n::Int) where T
i, j, l = @index(Global, NTuple)

Check warning on line 36 in src/kernels/tensor_mat_mul.jl

View check run for this annotation

Codecov / codecov/patch

src/kernels/tensor_mat_mul.jl#L35-L36

Added lines #L35 - L36 were not covered by tests
tmp_sum = zero(T)

# rows k ≥ j: entry (k, j) sits in the block of `S` belonging to row k
for k = j:n
tmp_sum += B[i, k, l] * S[(k - 1)* k ÷ 2 + j]
end

# rows k < j: read the mirrored position (row j, column k)
for k = 1:(j - 1)
tmp_sum += B[i, k, l] * S[(j - 1) * j ÷ 2 + k]
end

C[i, j, l] = tmp_sum
end

# In-place right multiplication of the tensor `B` with the symmetric matrix whose packed
# storage is `S` (with size parameter `n`); the result is written to `C`.
# The kernel is instantiated for the backend of `C` and launched over `size(C)`.
# Returns `nothing`.
function symmetric_mat_right_mul!(C::AbstractArray{T, 3}, B::AbstractArray{T, 3}, S::AbstractVector{T}, n::Int) where T
    launch! = symmetric_mat_right_mul_kernel!(KernelAbstractions.get_backend(C))
    launch!(C, B, S, n, ndrange = size(C))

    return nothing
end

# Out-of-place right multiplication of the tensor `B` with the symmetric matrix whose
# packed storage is `S` (with size parameter `n`).
# Allocates an output with the shape and element type of `B`, fills it via
# `symmetric_mat_right_mul!` and returns it.
function symmetric_mat_right_mul(B::AbstractArray{T, 3}, S::AbstractVector{T}, n::Int) where T
    # `symmetric_mat_right_mul!` launches its kernel over `size(C)` and assigns every
    # entry, so an uninitialized buffer is enough; copying `B` would be wasted work.
    C = similar(B)

    symmetric_mat_right_mul!(C, B, S, n)

    return C
end

# `tensor_mat_mul!` method for a `SymmetricMatrix` on the right: checks that the second
# axes of `C` and `B` match `A.n` and forwards the packed storage `A.S` to
# `symmetric_mat_right_mul!`, whose return value is passed through.
function tensor_mat_mul!(C::AbstractArray{T, 3}, B::AbstractArray{T, 3}, A::SymmetricMatrix{T}) where T
    @assert A.n == size(C, 2) == size(B, 2)

    packed_entries, matrix_dim = A.S, A.n
    return symmetric_mat_right_mul!(C, B, packed_entries, matrix_dim)
end
5 changes: 3 additions & 2 deletions test/custom_ad_rules/kernel_pullbacks.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using GeometricMachineLearning: lo_mat_mul, up_mat_mul, skew_mat_mul, symmetric_mat_mul
using GeometricMachineLearning: lo_mat_mul, up_mat_mul, skew_mat_mul, symmetric_mat_mul, symmetric_mat_right_mul
using GeometricMachineLearning: tensor_mat_mul, mat_tensor_mul, tensor_tensor_mul, tensor_transpose_tensor_mul, assign_q_and_p, tensor_transpose, assign_output_estimate, vec_tensor_mul, tensor_mat_skew_sym_assign
using ChainRulesTestUtils
using Printf
Expand All @@ -25,7 +25,8 @@ function main(first_dim, second_dim, third_dim, third_tensor_dim)
test_rrule(lo_mat_mul, rand(first_dim * (first_dim - 1) ÷ 2), rand(first_dim, first_dim, third_dim), first_dim, check_thunked_output_tangent = false)
test_rrule(up_mat_mul, rand(first_dim * (first_dim - 1) ÷ 2), rand(first_dim, first_dim, third_dim), first_dim, check_thunked_output_tangent = false)
test_rrule(skew_mat_mul, rand(first_dim * (first_dim - 1) ÷ 2), rand(first_dim, first_dim, third_dim), first_dim, check_thunked_output_tangent = false)
test_rrule(symmetric_mat_mul, rand(first_dim * (first_dim + 1) ÷ 2), rand(first_dim, first_dim, third_dim), first_dim, check_thunked_output_tangent = false)
test_rrule(symmetric_mat_mul, rand(first_dim * (first_dim + 1) ÷ 2), rand(first_dim, second_dim, third_dim), first_dim, check_thunked_output_tangent = false)
test_rrule(symmetric_mat_right_mul, rand(second_dim, first_dim, third_dim), rand(first_dim * (first_dim + 1) ÷ 2), first_dim, check_thunked_output_tangent = false)
test_rrule(tensor_mat_skew_sym_assign, rand(first_dim, second_dim, third_tensor_dim), rand(first_dim, first_dim), check_thunked_output_tangent = false)
end

Expand Down
36 changes: 24 additions & 12 deletions test/kernels/tensor_mat_mul.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,36 @@
using GeometricMachineLearning: tensor_mat_mul!
using GeometricMachineLearning
using GeometricMachineLearning: tensor_mat_mul!, tensor_mat_mul, allocate
import KernelAbstractions
using Random, Test

using CUDA
backend = CUDABackend()
Random.seed!(123)

dim1 = 256
dim2 = 123
dim3 = 45
num_data = 1000
backend = CPU()

const dim1 = 5
const dim2 = 12
const dim3 = 10
const num_data = 100

a = rand!(allocate(backend, Float32, dim1, dim2, num_data))
b = rand!(allocate(backend, Float32, dim2, dim3))
c = KernelAbstractions.zeros(backend, Float32, dim1, dim3, num_data)

@time tensor_mat_mul!(c,a,b)
KernelAbstractions.synchronize(backend)
tensor_mat_mul!(c, a, b)

c_manual = KernelAbstractions.zeros(backend, Float32, dim1, dim3, num_data)
@time for i in 1:num_data
c_manual[:,:,i] = a[:,:,i]*b
for i in 1:num_data
@views c_manual[:,:,i] = a[:,:,i] * b
end

@test isapprox(c, c_manual)

# Checks `tensor_mat_mul(B, A)` for a random `SymmetricMatrix` `A` against the ordinary
# matrix product applied to every slice `B[:, :, l]`.
#
# `first_dim`, `second_dim`, `third_dim` are the sizes of the tensor `B`; `A` has size
# parameter `second_dim`. The keyword `T` sets the element type of both `A` and `B`.
function test_tensor_multiplication(first_dim::Int=dim1, second_dim::Int=dim2, third_dim::Int=dim3; T = Float64)
    A = rand(SymmetricMatrix{T}, second_dim)
    # draw `B` with element type `T` as well; previously `T` was ignored here and `B`
    # was always `Float64`
    B = rand(T, first_dim, second_dim, third_dim)
    BA = tensor_mat_mul(B, A)
    for l in 1:third_dim
        @test (@view BA[:, :, l]) ≈ (@view B[:, :, l]) * A
    end
end

test_tensor_multiplication()
5 changes: 4 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ using SafeTestsets
@safetestset "Hamiltonian Neural Network " begin include("hamiltonian_neural_network_tests.jl") end
@safetestset "Manifold Neural Network Layers " begin include("layers/manifold_layers.jl") end

@safetestset "Custom tensor matrix multiplication " begin include("kernels/tensor_mat_mul.jl") end
@safetestset "Custom inverse for 2x2, 3x3, 4x4, 5x5 matrices " begin include("kernels/tensor_inverse.jl") end
@safetestset "Custom AD rules for kernels " begin include("custom_ad_rules/kernel_pullbacks.jl") end
@safetestset "ResNet " begin include("layers/resnet_tests.jl") end
Expand Down Expand Up @@ -72,4 +73,6 @@ using SafeTestsets
@safetestset "Batch functor(s) " begin include("batch/batch_functor.jl") end

@safetestset "Volume-Preserving Transformer (skew-symmetric tests) " begin include("volume_preserving_attention/test_skew_map.jl") end
@safetestset "Volume-Preserving Transformer (cayley-transform tests) " begin include("volume_preserving_attention/test_cayley_transforms.jl") end
@safetestset "Volume-Preserving Transformer (cayley-transform tests) " begin include("volume_preserving_attention/test_cayley_transforms.jl") end

# @safetestset "Linear Symplectic Attention " begin include("linear_symplectic_attention.jl") end
Loading