Skip to content

Commit

Permalink
changed naming, modified tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Christoph Ortner committed Mar 23, 2024
1 parent 2e18644 commit 8fb9270
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 19 deletions.
24 changes: 17 additions & 7 deletions src/solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,24 +151,30 @@ end
# solve(solver::SKLEARN_ARD, ...) is implemented in ext/

@doc raw"""
`struct Truncated_SVD` : linear least squares solver
`struct TruncatedSVD` : linear least squares solver for approximately solving
```math
θ = \arg\min \| A P^{-1} \theta - y \|^2
θ = \arg\min \| A \theta - y \|^2
```
- transform $\tilde\theta = P \theta$
- perform svd on $A P^{-1}$
- truncate svd at `rtol`, i.e. keep only the components for which $\sigma_i \geq {\rm rtol} \max \sigma_i$
- Compute $\tilde\theta$ from via pseudo-inverse
- Reverse transformation $\theta = P^{-1} \tilde\theta$
Constructor
```julia
ACEfit.Truncated_SVD(; lambda = 0.0, P = nothing)
ACEfit.TruncatedSVD(; rtol = 1e-9, P = I)
```
where
* `rtol` : relative tolerance
* `P` : right-preconditioner / tychonov operator
"""
struct Truncated_SVD
struct TruncatedSVD
rtol::Number
P::Any
end

Truncated_SVD(; rtol = 1e-9, P = I) = Truncated_SVD(rtol, P)
TruncatedSVD(; rtol = 1e-9, P = I) = TruncatedSVD(rtol, P)

function trunc_svd(USV::SVD, Y, rtol)
U, S, V = USV # svd(A)
Expand All @@ -179,9 +185,13 @@ function trunc_svd(USV::SVD, Y, rtol)
return V1 * (S1 .\ (U1' * Y))
end

function solve(solver::Truncated_SVD, A, y)
function solve(solver::TruncatedSVD, A, y)
AP = A / solver.P
θP = trunc_svd(svd(AP), y, solver.rtol)
print("Truncted SVD: perform svd ... ")
USV = svd(AP)
print("done. truncation ... ")
θP = trunc_svd(USV, y, solver.rtol)
println("done.")
return Dict{String, Any}("C" => solver.P \ θP)
end

49 changes: 37 additions & 12 deletions test/test_linearsolvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@ using PythonCall
@info("Test Solver on overdetermined system")
Nobs = 10_000
Nfeat = 100
A = randn(Nobs, Nfeat) / sqrt(Nobs)
y = randn(Nobs)
A1 = randn(Nobs, Nfeat) / sqrt(Nobs)
U, S1, V = svd(A)
S = 1e-4 .+ ((S .- S[end]) / (S[1] - S[end])).^2
A = U * Diagonal(S) * V'
c_ref = randn(Nfeat)
y = A * c_ref + 1e-3 * randn(Nobs) / sqrt(Nobs)
P = Diagonal(1.0 .+ rand(Nfeat))

@info(" ... QR")
Expand All @@ -16,73 +20,94 @@ results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)

@info(" ... regularised QR, λ = 1.0")
solver = ACEfit.QR(lambda = 1e0, P = P)
@info(" ... regularised QR, λ = 1e-5")
solver = ACEfit.QR(lambda = 1e-5, P = P)
results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)

@info(" ... regularised QR, λ = 10.0")
solver = ACEfit.QR(lambda = 1e1, P = P)
@info(" ... regularised QR, λ = 1e-2")
solver = ACEfit.QR(lambda = 1e-2, P = P)
results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)

@info(" ... RRQR, rtol = 1e-15")
solver = ACEfit.RRQR(rtol = 1e-15, P = P)
results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)

@info(" ... RRQR, rtol = 0.5")
solver = ACEfit.RRQR(rtol = 0.5, P = P)

@info(" ... RRQR, rtol = 1e-5")
solver = ACEfit.RRQR(rtol = 1e-5, P = P)
results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)

@info(" ... RRQR, rtol = 0.99")
solver = ACEfit.RRQR(rtol = 0.99, P = P)
@info(" ... RRQR, rtol = 1e-3")
solver = ACEfit.RRQR(rtol = 1e-3, P = P)
results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)

@info(" ... LSQR")
solver = ACEfit.LSQR(damp = 0, atol = 1e-6)
results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)

@info(" ... SKLEARN_BRR")
solver = ACEfit.SKLEARN_BRR()
results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)

@info(" ... SKLEARN_ARD")
solver = ACEfit.SKLEARN_ARD()
results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)

@info(" ... BLR")
solver = ACEfit.BLR()
results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)

@info(" ... Truncated_SVD")
solver = ACEfit.Truncated_SVD()
@info(" ... TruncatedSVD(; rtol = 1e-5)")
solver = ACEfit.TruncatedSVD(; rtol = 1e-5)
results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)

@info(" ... TruncatedSVD(; rtol = 1e-4)")
solver = ACEfit.TruncatedSVD(; rtol=1e-4)
results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)

0 comments on commit 8fb9270

Please sign in to comment.