Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Jan 17, 2025
1 parent 5a48e78 commit 3f10b81
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 1 deletion.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ Test = "1"
TestExtras = "0.2,0.3"
TupleTools = "1.1"
VectorInterface = "0.4, 0.5"
Zygote = "0.7"
julia = "1.10"

[extras]
Expand All @@ -53,6 +54,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "Combinatorics", "LinearAlgebra", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences"]
test = ["Aqua", "Combinatorics", "LinearAlgebra", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote"]
2 changes: 2 additions & 0 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),

test_rrule(copy, T1)
test_rrule(copy, T2)
test_rrule(TensorKit.copy_oftype, T1, ComplexF64)
test_rrule(TensorKit.permutedcopy_oftype, T1, ComplexF64, ((3, 1), (2, 4)))

test_rrule(convert, Array, T1)
test_rrule(TensorMap, convert(Array, T1), codomain(T1), domain(T1);
Expand Down
28 changes: 28 additions & 0 deletions test/bugfixes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,32 @@
@test storagetype(t5) == Vector{Float64}
tensorfree!(t2)
end

@testset "Issue #201" begin
function f(A::AbstractTensorMap)
U, S, V, = tsvd(A)
return tr(S)
end
function f(A::AbstractMatrix)
S = LinearAlgebra.svdvals(A)
return sum(S)
end
A₀ = randn(Z2Space(4, 4) Z2Space(4, 4))
grad1, = Zygote.gradient(f, A₀)
grad2, = Zygote.gradient(f, convert(Array, A₀))
@test convert(Array, grad1) grad2

function g(A::AbstractTensorMap)
U, S, V, = tsvd(A)
return tr(U * V)
end
function g(A::AbstractMatrix)
U, S, V, = LinearAlgebra.svd(A)
return tr(U * V')
end
B₀ = randn(ComplexSpace(4) ComplexSpace(4))
grad3, = Zygote.gradient(g, B₀)
grad4, = Zygote.gradient(g, convert(Array, B₀))
@test convert(Array, grad3) grad4
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using Base.Iterators: take, product
# using SUNRepresentations: SUNIrrep
# const SU3Irrep = SUNIrrep{3}
using LinearAlgebra: LinearAlgebra
using Zygote: Zygote

const TK = TensorKit

Expand Down

0 comments on commit 3f10b81

Please sign in to comment.