Skip to content

Commit

Permalink
Merge pull request #74 from ACEsuit/truncsvd
Browse files Browse the repository at this point in the history
Truncated SVD
  • Loading branch information
cortner authored Mar 28, 2024
2 parents 6c3a892 + 8f6ea3e commit 2070804
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 10 deletions.
47 changes: 47 additions & 0 deletions src/solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using LinearAlgebra: qr, I, norm
using LowRankApprox: pqrfact
using IterativeSolvers
using .BayesianLinear
using LinearAlgebra: SVD, svd

@doc raw"""
`struct QR` : linear least squares solver, using standard QR factorisation;
Expand Down Expand Up @@ -148,3 +149,49 @@ function SKLEARN_ARD(; n_iter = 300, tol = 1e-3, threshold_lambda = 10000)
end

# solve(solver::SKLEARN_ARD, ...) is implemented in ext/

@doc raw"""
`struct TruncatedSVD` : linear least squares solver for approximately solving
```math
θ = \arg\min \| A \theta - y \|^2
```
- transform $\tilde\theta = P \theta$
- perform svd on $A P^{-1}$
- truncate svd at `rtol`, i.e. keep only the components for which $\sigma_i \geq {\rm rtol} \max \sigma_i$
- Compute $\tilde\theta$ from via pseudo-inverse
- Reverse transformation $\theta = P^{-1} \tilde\theta$
Constructor
```julia
ACEfit.TruncatedSVD(; rtol = 1e-9, P = I)
```
where
* `rtol` : relative tolerance
* `P` : right-preconditioner / tychonov operator
"""
struct TruncatedSVD
rtol::Number
P::Any
end

TruncatedSVD(; rtol = 1e-9, P = I) = TruncatedSVD(rtol, P)

function trunc_svd(USV::SVD, Y, rtol)
U, S, V = USV # svd(A)
Ikeep = findall(x -> x > rtol, S ./ maximum(S))
U1 = @view U[:, Ikeep]
S1 = S[Ikeep]
V1 = @view V[:, Ikeep]
return V1 * (S1 .\ (U1' * Y))
end

function solve(solver::TruncatedSVD, A, y)
AP = A / solver.P
print("Truncted SVD: perform svd ... ")
USV = svd(AP)
print("done. truncation ... ")
θP = trunc_svd(USV, y, solver.rtol)
println("done.")
return Dict{String, Any}("C" => solver.P \ θP)
end

52 changes: 42 additions & 10 deletions test/test_linearsolvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@ using PythonCall
@info("Test Solver on overdetermined system")
Nobs = 10_000
Nfeat = 100
A = randn(Nobs, Nfeat) / sqrt(Nobs)
y = randn(Nobs)
A1 = randn(Nobs, Nfeat) / sqrt(Nobs)
U, S1, V = svd(A1)
S = 1e-4 .+ ((S1 .- S1[end]) / (S1[1] - S1[end])).^2
A = U * Diagonal(S) * V'
c_ref = randn(Nfeat)
y = A * c_ref + 1e-3 * randn(Nobs) / sqrt(Nobs)
P = Diagonal(1.0 .+ rand(Nfeat))

@info(" ... QR")
Expand All @@ -16,66 +20,94 @@ results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)

@info(" ... regularised QR, λ = 1.0")
solver = ACEfit.QR(lambda = 1e0, P = P)
@info(" ... regularised QR, λ = 1e-5")
solver = ACEfit.QR(lambda = 1e-5, P = P)
results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)

@info(" ... regularised QR, λ = 10.0")
solver = ACEfit.QR(lambda = 1e1, P = P)
@info(" ... regularised QR, λ = 1e-2")
solver = ACEfit.QR(lambda = 1e-2, P = P)
results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)

@info(" ... RRQR, rtol = 1e-15")
solver = ACEfit.RRQR(rtol = 1e-15, P = P)
results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)

@info(" ... RRQR, rtol = 0.5")
solver = ACEfit.RRQR(rtol = 0.5, P = P)

@info(" ... RRQR, rtol = 1e-5")
solver = ACEfit.RRQR(rtol = 1e-5, P = P)
results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)

@info(" ... RRQR, rtol = 0.99")
solver = ACEfit.RRQR(rtol = 0.99, P = P)
@info(" ... RRQR, rtol = 1e-3")
solver = ACEfit.RRQR(rtol = 1e-3, P = P)
results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)

@info(" ... LSQR")
solver = ACEfit.LSQR(damp = 0, atol = 1e-6)
results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)

@info(" ... SKLEARN_BRR")
solver = ACEfit.SKLEARN_BRR()
results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)

@info(" ... SKLEARN_ARD")
solver = ACEfit.SKLEARN_ARD()
results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)

@info(" ... BLR")
solver = ACEfit.BLR()
results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)

@info(" ... TruncatedSVD(; rtol = 1e-5)")
solver = ACEfit.TruncatedSVD(; rtol = 1e-5)
results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)

@info(" ... TruncatedSVD(; rtol = 1e-4)")
solver = ACEfit.TruncatedSVD(; rtol=1e-4)
results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)

0 comments on commit 2070804

Please sign in to comment.