Skip to content

Commit

Permalink
Merge pull request #99 from JuliaGNI/constructor_tests_for_custom_arrays
Browse files Browse the repository at this point in the history
Constructor tests for custom arrays
  • Loading branch information
michakraus authored Dec 20, 2023
2 parents eb94a21 + 569df79 commit 45cece4
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 9 deletions.
2 changes: 2 additions & 0 deletions src/arrays/stiefel_projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ Outer constructor for `StiefelProjection`. This works with two integers as input
"""
StiefelProjection(N::Integer, n::Integer, T::Type=Float64) = StiefelProjection(CPU(), T, N, n)

StiefelProjection(T::Type, N::Integer, n::Integer) = StiefelProjection(N, n, T)

Base.size(E::StiefelProjection) = (E.N, E.n)
Base.getindex(E::StiefelProjection, i, j) = getindex(E.A, i, j)
Base.:+(E::StiefelProjection, A::AbstractMatrix) = E.A + A
Expand Down
20 changes: 12 additions & 8 deletions src/data_loader/tensor_assign.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,22 @@ end
end

@doc raw"""
The function `assign_output_estimate` is closely related to the transformer. It takes the last prediction_window columns of the output and uses is for the final prediction.
The function `assign_output_estimate` is closely related to the transformer. It takes the last `prediction_window` columns of the output and uses them for the final prediction.
i.e.
```math
\mathbb{R}^{N\times\mathtt{pw}}\to\mathbb{R}^{N\times\mathtt{pw}}, \begin{bmatrix} z^{(1)}_1 & \cdots z^{(T)}_1 \\
\cdots & \cdots \\
z^{(T - \mathtt{pw})}_n & \cdots & z^{(T})_n\end{bmatrix} \mapsto
\begin{bmatrix} z^{(1)}_1 & \cdots z^{(T)}_1 \\
\cdots & \cdots \\
z^{(T - \mathtt{pw})}_n & \cdots & z^{(T})_n\end{bmatrix}
\mathbb{R}^{N\times\mathtt{pw}}\to\mathbb{R}^{N\times\mathtt{pw}},
\begin{bmatrix}
z^{(1)}_1 & \cdots & z^{(T)}_1 \\
\cdots & \cdots & \cdots \\
z^{(1)}_n & \cdots & z^{(T})_n
\end{bmatrix} \mapsto
\begin{bmatrix}
z^{(T - \mathtt{pw})}_1 & \cdots & z^{(T)}_1 \\
\cdots & \cdots & \cdots \\
z^{(T - \mathtt{pw})}_n & \cdots & z^{(T})_n\end{bmatrix}
```
"""
function assign_output_estimate(full_output::AbstractArray{T, 3}, prediction_window) where T
function assign_output_estimate(full_output::AbstractArray{T, 3}, prediction_window::Int) where T
sys_dim, seq_length, batch_size = size(full_output)
backend = KernelAbstractions.get_backend(full_output)
output_estimate = KernelAbstractions.allocate(backend, T, sys_dim, prediction_window, batch_size)
Expand Down
2 changes: 1 addition & 1 deletion src/manifolds/stiefel_manifold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ function Base.:*(Y::Adjoint{T, StiefelManifold{T, AT}}, B::AbstractMatrix) where
end

@doc raw"""
Computes the Riemannian gradient for the Stiefel manifold given an element ``Y\in{}St(N,n)`` and a matrix ``\nabla{}L\in\mahbb{R}^{N\times{}n}`` (the Euclidean gradient). It computes the Riemannian gradient with respect to the canonical metric (see the documentation for the function `metric` for an explanation of this).
Computes the Riemannian gradient for the Stiefel manifold given an element ``Y\in{}St(N,n)`` and a matrix ``\nabla{}L\in\mathbb{R}^{N\times{}n}`` (the Euclidean gradient). It computes the Riemannian gradient with respect to the canonical metric (see the documentation for the function `metric` for an explanation of this).
The precise form of the mapping is:
```math
\mathtt{rgrad}(Y, \nabla{}L) \mapsto \nabla{}L - Y(\nabla{}L)^TY
Expand Down
2 changes: 2 additions & 0 deletions test/arrays/array_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ function skew_mat_mul_test2(n, T=Float64)
AS2 = A*Matrix{T}(S)
@test isapprox(AS1, AS2)
end

# test Stiefel manifold projection test
function stiefel_proj_test(N,n)
In = I(n)
Expand All @@ -75,6 +76,7 @@ function stiefel_lie_alg_add_sub_test(N, n)
@test all(abs.(projection(W₁ - W₂) .- S₄) .< 1e-10)
end


function stiefel_lie_alg_vectorization_test(N, n; T=Float32)
A = rand(StiefelLieAlgHorMatrix{T}, N, n)
@test isapprox(StiefelLieAlgHorMatrix(vec(A), N, n), A)
Expand Down
37 changes: 37 additions & 0 deletions test/arrays/constructor_tests_for_custom_arrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
using GeometricMachineLearning, Test
using LinearAlgebra: I

@doc raw"""
This tests various constructor for custom arrays, e.g. if calling `SymmetricMatrix` on a matrix ``A`` does
```math
A \mapsto \frac{1}{2}(A + A^T).
```
"""
function test_constructors_for_custom_arrays(n::Int, N::Int, T::Type)
A = rand(T, n, n)
B = rand(T, N, N)

# SymmetricMatrix
@test Matrix{T}(SymmetricMatrix(A)) T(.5) * (A + A')

# SkewSymMatrix
@test Matrix{T}(SkewSymMatrix(A)) T(.5) * (A - A')

# StiefelLieAlgHorMatrix
B_shor = StiefelLieAlgHorMatrix(SkewSymMatrix(B), n)
B_shor2 = Matrix{T}(SkewSymMatrix(B))
B_shor2[(n+1):N, (n+1):N] .= zero(T)
@test Matrix{T}(B_shor) B_shor2

# GrassmannLieAlgHorMatrix
B_ghor = GrassmannLieAlgHorMatrix(SkewSymMatrix(B), n)
B_ghor2 = copy(B_shor2)
B_ghor2[1:n, 1:n] .= zero(T)
@test Matrix{T}(B_ghor) B_ghor2

# StiefelProjection
E = StiefelProjection(T, N, n)
@test Matrix{T}(E) vcat(I(n), zeros(T, (N-n), n))
end

test_constructors_for_custom_arrays(5, 10, Float32)
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using SafeTestsets
@safetestset "Addition tests for custom arrays " begin include("arrays/addition_tests_for_custom_arrays.jl") end
@safetestset "Scalar multiplication tests for custom arrays " begin include("arrays/scalar_multiplication_for_custom_arrays.jl") end
@safetestset "Matrix multiplication tests for custom arrays " begin include("arrays/matrix_multiplication_for_custom_arrays.jl") end
@safetestset "Test constructors for custom arrays " begin include("arrays/constructor_tests_for_custom_arrays.jl") end
@safetestset "Manifolds (Grassmann): " begin include("manifolds/grassmann_manifold.jl") end
@safetestset "Gradient Layer " begin include("layers/gradient_layer_tests.jl") end
@safetestset "Test symplecticity of upscaling layer " begin include("layers/sympnet_layers_test.jl") end
Expand Down

0 comments on commit 45cece4

Please sign in to comment.