From 3a9a4244ed2c87bde5d8456532755d93620c4c79 Mon Sep 17 00:00:00 2001 From: dussong Date: Fri, 18 Oct 2024 13:36:40 +0200 Subject: [PATCH 1/2] test cg and RI functions --- Project.toml | 14 ++- src/O3_alternative.jl | 92 ++++++++++++++++ src/RepLieGroups.jl | 2 + src/obsolete/O3.jl | 17 +++ test/runtests.jl | 5 +- test/test_RI_basis.jl | 125 +++++++++++++++++++++ test/test_RPI_basis.jl | 141 ++++++++++++++++++++++++ test/test_cg_vs_partialwavefunctions.jl | 18 +++ 8 files changed, 407 insertions(+), 7 deletions(-) create mode 100644 src/O3_alternative.jl create mode 100644 test/test_RI_basis.jl create mode 100644 test/test_RPI_basis.jl create mode 100644 test/test_cg_vs_partialwavefunctions.jl diff --git a/Project.toml b/Project.toml index e04ffc1..6eef21a 100644 --- a/Project.toml +++ b/Project.toml @@ -4,20 +4,24 @@ authors = ["Christoph Ortner and contributors"] version = "0.0.3" [deps] +Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +PartialWaveFunctions = "793d2195-304b-438e-bbb1-bc33c872ac39" +Polynomials4ML = "03c4bcba-a943-47e9-bfa1-b1661fc2974f" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +SpheriCart = "5caf2b29-02d9-47a3-9434-5931c85ba645" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [compat] -julia = "1" StaticArrays = "1.5" +julia = "1" [extras] -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" Polynomials4ML = "03c4bcba-a943-47e9-bfa1-b1661fc2974f" -WignerD = "87c4ff3e-34df-11e9-37a7-516cea4e0402" Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc" -BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +WignerD = "87c4ff3e-34df-11e9-37a7-516cea4e0402" [targets] -test = ["Test", "Polynomials4ML", "WignerD", "Rotations","BlockDiagonals"] +test = ["Test", "Polynomials4ML", "WignerD", "Rotations", "BlockDiagonals"] diff --git a/src/O3_alternative.jl b/src/O3_alternative.jl new file mode 100644 index 0000000..37d10b0 --- /dev/null +++ b/src/O3_alternative.jl @@ -0,0 +1,92 @@ +# Alternative to the computation of rotation equivariant coupling coefficients + +using PartialWaveFunctions +using Combinatorics +using LinearAlgebra + +export re_basis_new + +function CG(l,m,L,N) + M=m[1]+m[2] + if L[2] 0 +# return im # (-1)^m/sqrt(2) +# elseif m > 0 && μ < 0 +# return (-1)^m # - im * (-1)^m/sqrt(2) +# else +# return 1. # im/sqrt(2) +# end +# end + Ctran(l::Int64) = sparse(Matrix{ComplexF64}([ Ctran(l,m,μ) for m = -l:l, μ = -l:l ])) |> dropzeros ## NOTE: Ctran(L) is the transformation matrix from rSH to cSH. More specifically, diff --git a/test/runtests.jl b/test/runtests.jl index 07ca1e4..50b5496 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,6 +4,7 @@ using Test @testset "RepLieGroups.jl" begin # Write your tests here. @testset "SYYVector" begin include("test_yyvector.jl"); end - @testset "O3-ClebschGordan" begin include("test_obsolete_cg.jl"); end - @testset "O3" begin include("test_obsolete_o3.jl"); end + # @testset "O3-ClebschGordan" begin include("test_obsolete_cg.jl"); end + # @testset "O3" begin include("test_obsolete_o3.jl"); end + @testset "CGcoef vs PartialWaveFunctions" begin include("test_cg_vs_partialwavefunctions.jl"); end end diff --git a/test/test_RI_basis.jl b/test/test_RI_basis.jl new file mode 100644 index 0000000..1f13b77 --- /dev/null +++ b/test/test_RI_basis.jl @@ -0,0 +1,125 @@ +using SpheriCart, StaticArrays, LinearAlgebra, RepLieGroups, WignerD, + Combinatorics +using RepLieGroups.O3: Rot3DCoeffs, Rot3DCoeffs_real +O3 = RepLieGroups.O3 +using Test + +# Evaluation of spherical harmonics +function eval_cY(rbasis::SphericalHarmonics{LMAX}, 𝐫) where {LMAX} + Yr = rbasis(𝐫) + Yc = zeros(Complex{eltype(Yr)}, length(Yr)) + for l = 0:LMAX + # m = 0 + i_l0 = SpheriCart.lm2idx(l, 0) + Yc[i_l0] = Yr[i_l0] + # m ≠ 0 + for m = 1:l + i_lm⁺ = SpheriCart.lm2idx(l, m) + i_lm⁻ = SpheriCart.lm2idx(l, -m) + Ylm⁺ = Yr[i_lm⁺] + Ylm⁻ = Yr[i_lm⁻] + Yc[i_lm⁺] = (-1)^m * (Ylm⁺ + im * Ylm⁻) / sqrt(2) + Yc[i_lm⁻] = (Ylm⁺ - im * Ylm⁻) / sqrt(2) + end + end + return Yc + end + + function rand_sphere() + u = @SVector randn(3) + return u / norm(u) + end + + function rand_rot() + K = @SMatrix randn(3,3) + return exp(K - K') + end + + function f(Rs, q; coeffs=coeffs, MM=MM, ll=ll) + Lmax = maximum(ll) + real_basis = SphericalHarmonics(Lmax) + YY = [] + for i in 1:length(ll) + push!(YY, eval_cY(real_basis, Rs[i])) + end + out = zero(eltype(YY[1])) + for (c, mm) in zip(coeffs[q, :], MM) + ind = Int[] + for i in 1:length(ll) + push!(ind, SpheriCart.lm2idx(ll[i], mm[i])) + end + out += c * prod(YY[i][ind[i]] for i in 1:length(ll)) + end + return out +end + + +# for the moment the code with generalized CG only works with L=0 +L = 0 +cc = Rot3DCoeffs(L) +ll = SA[2,2,2,3,3] + +# version with svd +@time coeffs1, MM1 = O3.re_basis(cc, ll) +nbas = size(coeffs1, 1) + +#version with gen CG coefficients +@time coeffs2, MM2 = re_basis_new(ll) + +# simple test on size +@test size(coeffs1) == size(coeffs2) +@test size(MM1) == size(MM2) + +P1 = sortperm(MM1) +P2 = sortperm(MM2) +MMsorted1 = MM1[P1] +MMsorted2 = MM2[P2] +# check that same mm values +@test MMsorted1 == MMsorted2 + +coeffsp1 = coeffs1[:,P1] +coeffsp2 = coeffs2[:,P2] + +# test that full rank +@test rank(coeffsp1) == size(coeffsp1,1) +@test rank(coeffsp2) == size(coeffsp2,1) + +# check that the coef span the same space - test fails +@test nbas == rank([coeffsp; coeffsp2], rtol = 1e-12) + + +Rs = [rand_sphere() for _ in 1:length(ll)] +Q = rand_rot() +QRs = [Q*Rs[i] for i in 1:length(Rs)] +fRs1 = [ f(Rs, q; coeffs=coeffs1, MM=MM1, ll=ll) for q = 1:nbas ] +fRs1Q = [ f(QRs, q; coeffs=coeffs1, MM=MM1, ll=ll) for q = 1:nbas ] + +# check invariance (for now) +@test norm(fRs1 .- fRs1Q) < 1e-12 + +fRs2 = [ f(Rs, q; coeffs=coeffs2, MM=MM2, ll=ll) for q = 1:nbas ] +fRs2Q = [ f(QRs, q; coeffs=coeffs2, MM=MM2, ll=ll) for q = 1:nbas ] + +# check invariance (for now) +@test norm(fRs2 .- fRs2Q) < 1e-12 + +# Test on batch +ntest = 1000 +A1 = zeros(nbas, ntest) +A2 = zeros(nbas, ntest) +for i = 1:ntest + Rs = [rand_sphere() for _ in 1:length(ll)] + for q = 1:nbas + fRs = f(Rs, q; coeffs = coeffs1, MM=MM1, ll=ll) + @assert abs.(imag(fRs)) < 1e-16 + A1[q, i] = real(fRs) + + fRs2 = f(Rs, q; coeffs = coeffs2, MM=MM2, ll=ll) + @assert abs.(imag(fRs2)) < 1e-16 + A2[q, i] = real(fRs2) + end +end + +# check that functions span same space +rk = rank([A1;A2]; rtol = 1e-12) +@test rk == nbas \ No newline at end of file diff --git a/test/test_RPI_basis.jl b/test/test_RPI_basis.jl new file mode 100644 index 0000000..6ac48fc --- /dev/null +++ b/test/test_RPI_basis.jl @@ -0,0 +1,141 @@ + +using SpheriCart, StaticArrays, LinearAlgebra, RepLieGroups, WignerD, + Combinatorics +using Rotations +using RepLieGroups.O3: Rot3DCoeffs +O3 = RepLieGroups.O3 + +function eval_cY(rbasis::SphericalHarmonics{LMAX}, 𝐫) where {LMAX} + Yr = rbasis(𝐫) + Yc = zeros(Complex{eltype(Yr)}, length(Yr)) + for l = 0:LMAX + # m = 0 + i_l0 = SpheriCart.lm2idx(l, 0) + Yc[i_l0] = Yr[i_l0] + # m ≠ 0 + for m = 1:l + i_lm⁺ = SpheriCart.lm2idx(l, m) + i_lm⁻ = SpheriCart.lm2idx(l, -m) + Ylm⁺ = Yr[i_lm⁺] + Ylm⁻ = Yr[i_lm⁻] + Yc[i_lm⁺] = (-1)^m * (Ylm⁺ + im * Ylm⁻) / sqrt(2) + Yc[i_lm⁻] = (Ylm⁺ - im * Ylm⁻) / sqrt(2) + end + end + return Yc +end + +function rand_sphere() + u = @SVector randn(3) + return u / norm(u) +end + +function rand_rot() + K = @SMatrix randn(3,3) + return exp(K - K') +end + + +## +# CASE 2: 4-correlations, L = 0 (revisited) +L = 0 +cc = Rot3DCoeffs(0) +# now we fix an ll = (l1, l2, l3) triple ask for all possible linear combinations +# of the tensor product basis Y[l1, m1] * Y[l2, m2] * Y[l3, m3] * Y[l4, m4] +# that are invariant under O(3) rotations. +ll = SA[3, 3, 2, 2, 2] +coeffs, MM = O3.re_basis(cc, ll) +nbas = size(coeffs, 1) +# coeffs = nbasis x length(MM) matrix +# MM = vector of (m1, m2, m3, m4) tuples +# for 4-correlations and higher, the number of possible couplings can be +# greater than one. An interesting result of this is that even if those +# basis functions as tensor products of Ylms are linearly independent, they +# need no longer be linearly independent once we impose permutation-invariance. + +""" +This implements a permutation-invariant and O(3) invariant function of +4 (or more) variables. +""" +function f(Rs, q::Integer; coeffs=coeffs, MM=MM) + real_basis = SphericalHarmonics(3) + Y = [ eval_cY(real_basis, 𝐫) for 𝐫 in Rs ] + A = sum(Y) # this is a permutation-invariant embedding + out = zero(eltype(A)) + for (c, mm) in zip(coeffs[q, :], MM) + ii = [SpheriCart.lm2idx(ll[α], mm[α]) for α in 1:4] + out += c * prod(A[ii]) + end + return real(out) +end + + +Rs = [rand_sphere() for _ in 1:4] +[ f(Rs, q) for q = 1:nbas ] + +# we can look for linear independence... + +function f_rand_batch(coeffs, MM, ntest) + nbas = size(coeffs, 1) + A = zeros(nbas, ntest) + for i = 1:ntest + Rs = [rand_sphere() for _ in 1:4] + for q = 1:nbas + A[q, i] = f(Rs, q; coeffs = coeffs) + end + end + return A +end + +A = f_rand_batch(coeffs, MM, 1_000) + +# the rank of this matrix is only 3, not nbas = 5! +rk = rank(A) +# 3 + +# this is in fact very clear from the SVD +svdvals(A) +# 5-element Vector{Float64}: +# 25.568548451339442 +# 6.754493909270821 +# 5.45667827344765 +# 5.398138581058422e-15 +# 1.6731699590390224e-15 + +## +# In ACE we use a semi-analytic construction to make the basis functions +# linearly indepdendent. +# in a full ACE code, this is a bit more complex since n channels are added +# to the story. This can be found here: +# https://github.com/ACEsuit/ACE1.jl/blob/8ac52d2128241a01d8b9a036f41b1d5106cbeb07/src/rpi/rotations3d.jl#L334 +# https://github.com/ACEsuit/ACE1.jl/blob/8ac52d2128241a01d8b9a036f41b1d5106cbeb07/src/rpi/rotations3d.jl#L348 + +# For simplicity, we can just use the SVD to construct a +# linearly independent basis. + +U, S, V = svd(A) +coeffs_ind = Diagonal(S[1:rk]) \ (U[:, 1:rk]' * coeffs) + +## +# now let's re-compute A + +ntest = 1_000 +A_ind = zeros(rk, ntest) +for i = 1:ntest + Rs = [rand_sphere() for _ in 1:4] + for q = 1:rk + A_ind[q, i] = f(Rs, q; coeffs = coeffs_ind) + end +end + +# the rank of this matrix is 3, which is what we expect +rk_ind = rank(A_ind) +# 3 + +# And we even scaled the coupling coeffs so that the singular values are +# now all close to 1. +svdvals(A_ind) +# 3-element Vector{Float64}: +# 0.997130965666211 +# 0.8853910366671397 +# 0.8668533817486788 \ No newline at end of file diff --git a/test/test_cg_vs_partialwavefunctions.jl b/test/test_cg_vs_partialwavefunctions.jl new file mode 100644 index 0000000..3ae8845 --- /dev/null +++ b/test/test_cg_vs_partialwavefunctions.jl @@ -0,0 +1,18 @@ +using PartialWaveFunctions +using RepLieGroups.O3: ClebschGordan + +@info("Testing the correctness of PartialWaveFunctions") +Lmax = 6 +for j1 in 1:Lmax + for m1 in -j1:j1 + for j2 in 1:Lmax + for m2 in -j2:j2 + for J in 1:Lmax + for M in -J:J + @test abs(clebschgordan(j1, m1, j2, m2, J, M, Float64) - PartialWaveFunctions.clebschgordan(j1, m1, j2, m2, J, M) < 1e-14) + end + end + end + end + end +end From 64d8b4495f46eb3163909c7c6af29d19fdaa9dd4 Mon Sep 17 00:00:00 2001 From: dussong Date: Thu, 31 Oct 2024 10:55:12 +0100 Subject: [PATCH 2/2] test RPI - debug --- Project.toml | 1 + src/O3_alternative.jl | 176 ++++++++++++++++++++++++++++++++++++++++- test/test_RI_basis.jl | 66 +++++++++++++++- test/test_RPI_basis.jl | 126 +++++++++++------------------ 4 files changed, 282 insertions(+), 87 deletions(-) diff --git a/Project.toml b/Project.toml index 6eef21a..b820cb9 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ Polynomials4ML = "03c4bcba-a943-47e9-bfa1-b1661fc2974f" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpheriCart = "5caf2b29-02d9-47a3-9434-5931c85ba645" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] StaticArrays = "1.5" diff --git a/src/O3_alternative.jl b/src/O3_alternative.jl index 37d10b0..5ade884 100644 --- a/src/O3_alternative.jl +++ b/src/O3_alternative.jl @@ -4,7 +4,7 @@ using PartialWaveFunctions using Combinatorics using LinearAlgebra -export re_basis_new +export re_basis_new, ri_basis_new, ind_corr_s1, ind_corr_s2, MatFmi, ML0 function CG(l,m,L,N) M=m[1]+m[2] @@ -50,6 +50,30 @@ function SetLl0(l,N) return set end +function SetLl(l,N,L) + set = Vector{Int64}[] + for k in abs(l[1]-l[2]):l[1]+l[2] + push!(set, [0; k]) + end + for k in 3:N-1 + setL=set + set=Vector{Int64}[] + for a in setL + for b in abs(a[k-1]-l[k]):a[k-1]+l[k] + push!(set, [a; b]) + end + end + end + setL=set + set=Vector{Int64}[] + for a in setL + if (abs.(a[N-1]-l[N]) <= L)&&(L <= (a[N-1]+l[N])) + push!(set, [a; L]) + end + end + return set +end + # Function that computes the set ML0 function ML0(l,N) setML = [[i] for i in -abs(l[1]):abs(l[1])] @@ -70,7 +94,29 @@ function ML0(l,N) return setML0 end -function re_basis_new(l) +# Function that computes the set ML (relative to equivariance L) +function ML(l,N,L) + setML = [[i] for i in -abs(l[1]):abs(l[1])] + for k in 2:N-1 + set = setML + setML = Vector{Int64}[] + for m in set + append!(setML, [m; lk] for lk in -abs(l[k]):abs(l[k]) ) + end + end + setML0=Vector{Int64}[] + for m in setML + s=sum(m) + for mn in -L-s:L-s + if abs(mn) < abs(l[N])+1 + push!(setML0, [m; mn]) + end + end + end + return setML0 +end + +function ri_basis_new(l) N=size(l,1) L=SetLl0(l,N) r=size(L,1) @@ -89,4 +135,130 @@ function re_basis_new(l) end end return U,M +end + +function re_basis_new(l,L) + N=size(l,1) + Ll=SetLl(l,N,L) + r=size(Ll,1) + if r==0 + return zeros(Float64, 0, 0) + else + setML0=ML(l,N,L) + sizeML0=length(setML0) + U=zeros(Float64, r, sizeML0) + M = Vector{Int64}[] + for (j,m) in enumerate(setML0) + push!(M,m) + for i in 1:r + U[i,j]=CG(l,m,Ll[i],N) + end + end + end + return U,M +end + + +# Function that computes the permutations that let n and l invariant +function Snl(N,n,l) + if n==n[1]*ones(N) + if l==l[1]*ones(N) + return permutations(1:N) + end + end + if N==1 + return Set([[1]]) + elseif (n[N-1],l[N-1])!=(n[N],l[N]) + S=Set() + Sn=Snl(N-1,n[1:N-1],l[1:N-1]) + for x in Sn + append!(x,[N]) + union!(S,Set([x])) + end + else + S=Set() + k=N + while (n[k-1],l[k-1])==(n[k],l[k]) && k>2 + k-=1 + end + if k==2 && (n[1],l[1])==(n[2],l[2]) + return Set(permutations(1:N)) + else + Sn=Snl(k-1,n[1:k-1],l[1:k-1]) + for x in Sn + for s in Set(permutations(k:N)) + y=copy(x) + append!(y,s) + union!(S,Set([y])) + end + end + end + end + return S +end + + +#Function that computes the set of classes using the set Ml0 and the possible permutations +function class(setML0,sigma,N,l) + setclass=Vector{Vector{Int64}}[] + pop!(setML0,zeros(Int64,N)) + while setML0!=Set() + x=pop!(setML0) + p=[x] + for s in sigma + y=x[s] + if y in setML0 + append!(p,[y]) + pop!(setML0,y) + end + end + append!(setclass,[p]) + end + setclasses=Vector{Vector{Int64}}[] + for x in setclass + for y in setclass + if x==y + if minimum(x)==minimum(-x) + if iseven(sum(l)) + append!(setclasses,[x]) + end + end + elseif minimum(x)==minimum(-y) + if y