From f892f486af301a890eecb44744df6e58df06aba8 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Wed, 11 Sep 2024 14:53:24 -0700 Subject: [PATCH 1/2] add svd with validation set --- src/asp.jl | 49 ++++++-------------------------------- src/solvers.jl | 34 ++++++++++++++++++++++++++ test/test_linearsolvers.jl | 21 ++++++++++++---- 3 files changed, 58 insertions(+), 46 deletions(-) diff --git a/src/asp.jl b/src/asp.jl index 88a2aa7..4f690c4 100644 --- a/src/asp.jl +++ b/src/asp.jl @@ -66,20 +66,23 @@ end function solve(solver::ASP, A, y, Aval=A, yval=y) # Apply preconditioning AP = A / solver.P + AvalP = Aval / solver.P tracer = asp_homotopy(AP, y; solver.params...) q = length(tracer) every = max(1, q ÷ solver.nstore) istore = unique([1:every:q; q]) - new_tracer = [ (solution = solver.P \ tracer[i][1], λ = tracer[i][2], σ = 0.0 ) + new_tracer = [ (solution = tracer[i][1], λ = tracer[i][2], σ = 0.0 ) for i in istore ] if solver.tsvd # Post-processing if tsvd is true - post = post_asp_tsvd(new_tracer, A, y, Aval, yval) - new_post = [ (solution = p.θ, λ = p.λ, σ = p.σ) for p in post ] + post = post_asp_tsvd(new_tracer, AP, y, AvalP, yval) + new_post = [ (solution = solver.P \ p.θ, λ = p.λ, σ = p.σ) + for p in post ] else - new_post = new_tracer + new_post = [ (solution = solver.P \ p.solution, λ = p.λ, σ = 0.0) + for p in new_tracer ] end xs, in = select_solution(new_post, solver, Aval, yval) @@ -124,34 +127,6 @@ function select_solution(tracer, solver, A, y) end - -using SparseArrays - -function solve_tsvd(At, yt, Av, yv) - Ut, Σt, Vt = svd(At); zt = Ut' * yt - Qv, Rv = qr(Av); zv = Matrix(Qv)' * yv - @assert issorted(Σt, rev=true) - - Rv_Vt = Rv * Vt - - θv = zeros(size(Av, 2)) - θv[1] = zt[1] / Σt[1] - rv = Rv_Vt[:, 1] * θv[1] - zv - - tsvd_errs = Float64[] - push!(tsvd_errs, norm(rv)) - - for k = 2:length(Σt) - θv[k] = zt[k] / Σt[k] - rv += Rv_Vt[:, k] * θv[k] - push!(tsvd_errs, norm(rv)) - end - - imin = argmin(tsvd_errs) - θv[imin+1:end] .= 0 - return Vt * θv, Σt[imin] -end - function post_asp_tsvd(path, At, yt, Av, yv) Qt, Rt = qr(At); zt = Matrix(Qt)' * yt Qv, Rv = qr(Av); zv = Matrix(Qv)' * yv @@ -166,14 +141,4 @@ function post_asp_tsvd(path, At, yt, Av, yv) end return _post.(path) - -# post = [] -# for (θ, λ) in path -# if isempty(θ.nzind); push!(post, (θ = θ, λ = λ, σ = Inf)); continue; end -# inz = θ.nzind -# θ1, σ = solve_tsvd(Rt[:, inz], zt, Rv[:, inz], zv) -# θ2 = copy(θ); θ2[inz] .= θ1 -# push!(post, (θ = θ2, λ = λ, σ = σ)) -# end -# return identity.(post) end diff --git a/src/solvers.jl b/src/solvers.jl index 9867271..fac6ee2 100644 --- a/src/solvers.jl +++ b/src/solvers.jl @@ -196,3 +196,37 @@ function solve(solver::TruncatedSVD, A, y) return Dict{String, Any}("C" => solver.P \ θP) end + +# ------------ Truncated SVD with tol specified by validation set ------------ + +function solve_tsvd(At, yt, Av, yv) + Ut, Σt, Vt = svd(At); zt = Ut' * yt + Qv, Rv = qr(Av); zv = Matrix(Qv)' * yv + @assert issorted(Σt, rev=true) + + Rv_Vt = Rv * Vt + + θv = zeros(size(Av, 2)) + θv[1] = zt[1] / Σt[1] + rv = Rv_Vt[:, 1] * θv[1] - zv + + tsvd_errs = Float64[] + push!(tsvd_errs, norm(rv)) + + for k = 2:length(Σt) + θv[k] = zt[k] / Σt[k] + rv += Rv_Vt[:, k] * θv[k] + push!(tsvd_errs, norm(rv)) + end + + imin = argmin(tsvd_errs) + θv[imin+1:end] .= 0 + return Vt * θv, Σt[imin] + end + + +function solve(solver::TruncatedSVD, At, yt, Av, yv) + # make a function barrier because solver.P is not inferred + θ, σ = solve_tsvd(At / solver.P, yt, Av / solver.P, yv) + return Dict{String, Any}("C" => solver.P \ θ, "σ" => σ) +end diff --git a/test/test_linearsolvers.jl b/test/test_linearsolvers.jl index 9c03b53..c0087ba 100644 --- a/test/test_linearsolvers.jl +++ b/test/test_linearsolvers.jl @@ -1,8 +1,5 @@ -using ACEfit -using LinearAlgebra, Random, Test -using Random -using PythonCall +using ACEfit, LinearAlgebra, Random, Test, PythonCall ## @@ -168,3 +165,19 @@ C = results["C"] @test norm(A * C - y) < 10 * epsn @test norm(C - c_ref) < 1 + +## + +@info("Truncated SVD with validation") +solver = ACEfit.TruncatedSVD(; rtol = 0.0) +At = A[1:8000, :] +yt = y[1:8000] +Av = A[8001:end, :] +yv = y[8001:end] +results_v = ACEfit.solve(solver, At, yt, Av, yv) +@show err_v = norm(Av * results_v["C"] - yv) +@show err = norm(Av * results["C"] - yv) +@test err_v <= err +@show norm(results_v["C"] - c_ref) +@show norm(results["C"] - c_ref) +@test norm(results_v["C"] - c_ref) < 1e-2 \ No newline at end of file From 56dcdcceda84ee460a84adf7055b01efdd31e954 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Wed, 11 Sep 2024 15:04:30 -0700 Subject: [PATCH 2/2] asp bugfix --- src/asp.jl | 4 ++-- test/test_asp.jl | 7 ++++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/asp.jl b/src/asp.jl index 4f690c4..c5a6cbb 100644 --- a/src/asp.jl +++ b/src/asp.jl @@ -71,8 +71,8 @@ function solve(solver::ASP, A, y, Aval=A, yval=y) tracer = asp_homotopy(AP, y; solver.params...) q = length(tracer) - every = max(1, q ÷ solver.nstore) - istore = unique([1:every:q; q]) + every = max(1, q / solver.nstore) + istore = unique(round.(Int, [1:every:q; q])) new_tracer = [ (solution = tracer[i][1], λ = tracer[i][2], σ = 0.0 ) for i in istore ] diff --git a/test/test_asp.jl b/test/test_asp.jl index 4cacb64..e607bed 100644 --- a/test/test_asp.jl +++ b/test/test_asp.jl @@ -1,6 +1,5 @@ using ACEfit using LinearAlgebra, Random, Test -using Random ## @@ -29,6 +28,12 @@ Av = A[val_indices,:] yt = y[train_indices] yv = y[val_indices] +for (nstore, n1) in [ (20, 21), (100, 101), (200, 165)] + solver = ACEfit.ASP(P=I, select = :final, nstore = nstore, loglevel=0, traceFlag=true) + results = ACEfit.solve(solver, A, y) + @test length(results["path"]) == n1 +end + for (select, tolr, tolc) in [ (:final, 10*epsn, 1), ( (:byerror,1.3), 10*epsn, 1), ( (:bysize,73), 1, 10) ]