Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Christoph Ortner committed Sep 14, 2024
1 parent 44418b6 commit ce319ec
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 31 deletions.
45 changes: 30 additions & 15 deletions src/asp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,19 +83,28 @@ function solve(solver::ASP, A, y, Aval=A, yval=y)
for p in new_tracer ]
end

xs, in = select_solution(new_post, solver, Aval, yval)
tracer_final = _add_errors(new_post, Aval, yval)
xs, in = asp_select(tracer_final, solver.select)

return Dict( "C" => xs,
"path" => new_post,
"nnzs" => length( (new_post[in][:solution]).nzind) )
return Dict( "C" => xs,
"path" => tracer_final, )
end


function select_solution(tracer, solver, A, y)
if solver.select == :final
function _add_errors(tracer, A, y)
rtN = sqrt(length(y))
return [ ( solution = t.solution, λ = t.λ, σ = t.σ,
rmse = norm(A * t.solution - y) / rtN )
for t in tracer ]
end

asp_select(D::Dict, select) = asp_select(D["path"], select)

function asp_select(tracer, select)
if select == :final
criterion = :final
else
criterion, p = solver.select
criterion, p = select
end

if criterion == :final
Expand All @@ -108,12 +117,12 @@ function select_solution(tracer, solver, A, y)
elseif criterion == :bysize
maxind = findfirst(t -> length((t[:solution]).nzind) > p,
tracer) - 1
threshold = 1.0
threshold = 1.0
else
error("Unknown selection criterion: $criterion")
end

errors = [ norm(A * t[:solution] - y) for t in tracer[1:maxind] ]
errors = [ t.rmse for t in tracer[1:maxind] ]
min_error = minimum(errors)
for (i, error) in enumerate(errors)
if error <= threshold * min_error
Expand Down Expand Up @@ -141,9 +150,15 @@ function post_asp_tsvd(path, At, yt, Av, yv)
return _post.(path)
end

function select(tracer, solver, A, y) #can be called by the user to warm-start the selection
xs, in = select_solution(tracer, solver, A, y)
return Dict("C" => xs,
"path" => tracer,
"nnzs" => length( (tracer[in][:solution]).nzind) )
end
# TODO: revisit this idea. Maybe we do want to keep this, not as `select`
# but as `solve`. But if we do, then it might be nice to be able to
# extend the path somehow. For now I'm removing it since I don't see
# the immediate need yet. Just calling asp_select is how I would normally
# use this.
#
# function select(tracer, solver, A, y) #can be called by the user to warm-start the selection
# xs, in = select_solution(tracer, solver, A, y)
# return Dict("C" => xs,
# "path" => tracer,
# "nnzs" => length( (tracer[in][:solution]).nzind) )
# end
27 changes: 11 additions & 16 deletions test/test_asp.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using ACEfit
using LinearAlgebra, Random, Test
using LinearAlgebra, Random, Test

##

Expand Down Expand Up @@ -47,7 +47,7 @@ for (select, tolr, tolc) in [ (:final, 10*epsn, 1),
results = ACEfit.solve(solver, A, y)
C = results["C"]
full_path = results["path"]
@show results["nnzs"]
# @show results["nnzs"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)
Expand All @@ -60,7 +60,7 @@ for (select, tolr, tolc) in [ (:final, 10*epsn, 1),
results = ACEfit.solve(solver, At, yt, Av, yv)
C = results["C"]
full_path = results["path"]
@show results["nnzs"]
# @show results["nnzs"]
@show norm(Av * C - yv)
@show norm(C)
@show norm(C - c_ref)
Expand Down Expand Up @@ -91,7 +91,7 @@ for (select, tolr, tolc) in [ (:final, 20*epsn, 1.5),
C_tsvd = results_tsvd["C"]
C = results["C"]

@show results["nnzs"]
# @show results["nnzs"]
@show norm(A * C - y)
@show norm(A * C_tsvd - y)
if norm(A * C_tsvd - y)< norm(A * C - y)
Expand All @@ -106,7 +106,7 @@ for (select, tolr, tolc) in [ (:final, 20*epsn, 1.5),
results = ACEfit.solve(solver, At, yt, Av, yv)
C_tsvd = results_tsvd["C"]
C = results["C"]
@show results["nnzs"]
# @show results["nnzs"]
@show norm(A * C - y)
@show norm(A * C_tsvd - y)

Expand All @@ -117,6 +117,7 @@ for (select, tolr, tolc) in [ (:final, 20*epsn, 1.5),
end
end

##

# Testing the "select" function
solver_final = ACEfit.ASP(
Expand All @@ -131,16 +132,10 @@ results_final = ACEfit.solve(solver_final, At, yt, Av, yv)
tracer_final = results_final["path"]

# Warm-start the solver using the tracer from the final iteration
solver_warmstart = ACEfit.ASP(
P = I,
select = (:bysize, 73),
tsvd = false,
nstore = 100,
loglevel = 0
)

results_warmstart = ACEfit.select(tracer_final, solver_warmstart, Av, yv)
C_warmstart = results_warmstart["C"]
# select best solution with <= 73 non-zero entries
select = (:bysize, 73)
C_select, _ = ACEfit.asp_select(tracer_final, select)
@test( length(C_select.nzind) <= 73 )

# Check if starting the solver initially with (:bysize, 73) gives the same result
solver_bysize = ACEfit.ASP(
Expand All @@ -152,6 +147,6 @@ solver_bysize = ACEfit.ASP(
)

results_bysize = ACEfit.solve(solver_bysize, At, yt, Av, yv)
@test results_bysize["C"] == C_warmstart # works
@test results_bysize["C"] == C_select # works


0 comments on commit ce319ec

Please sign in to comment.