Skip to content

Commit

Permalink
Merge pull request #88 from ACEsuit/aspselect
Browse files Browse the repository at this point in the history
Post fit selection from asp path
  • Loading branch information
cortner committed Sep 14, 2024
2 parents 391ec74 + ce319ec commit 8062ad9
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 14 deletions.
40 changes: 31 additions & 9 deletions src/asp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,19 +83,28 @@ function solve(solver::ASP, A, y, Aval=A, yval=y)
for p in new_tracer ]
end

xs, in = select_solution(new_post, solver, Aval, yval)
tracer_final = _add_errors(new_post, Aval, yval)
xs, in = asp_select(tracer_final, solver.select)

return Dict( "C" => xs,
"path" => new_post,
"nnzs" => length( (new_post[in][:solution]).nzind) )
return Dict( "C" => xs,
"path" => tracer_final, )
end


function select_solution(tracer, solver, A, y)
if solver.select == :final
function _add_errors(tracer, A, y)
rtN = sqrt(length(y))
return [ ( solution = t.solution, λ = t.λ, σ = t.σ,
rmse = norm(A * t.solution - y) / rtN )
for t in tracer ]
end

asp_select(D::Dict, select) = asp_select(D["path"], select)

function asp_select(tracer, select)
if select == :final
criterion = :final
else
criterion, p = solver.select
criterion, p = select
end

if criterion == :final
Expand All @@ -108,12 +117,12 @@ function select_solution(tracer, solver, A, y)
elseif criterion == :bysize
maxind = findfirst(t -> length((t[:solution]).nzind) > p,
tracer) - 1
threshold = 1.0
threshold = 1.0
else
error("Unknown selection criterion: $criterion")
end

errors = [ norm(A * t[:solution] - y) for t in tracer[1:maxind] ]
errors = [ t.rmse for t in tracer[1:maxind] ]
min_error = minimum(errors)
for (i, error) in enumerate(errors)
if error <= threshold * min_error
Expand All @@ -140,3 +149,16 @@ function post_asp_tsvd(path, At, yt, Av, yv)

return _post.(path)
end

# TODO: revisit this idea. Maybe we do want to keep this, not as `select`
# but as `solve`. But if we do, then it might be nice to be able to
# extend the path somehow. For now I'm removing it since I don't see
# the immediate need yet. Just calling asp_select is how I would normally
# use this.
#
# function select(tracer, solver, A, y) #can be called by the user to warm-start the selection
# xs, in = select_solution(tracer, solver, A, y)
# return Dict("C" => xs,
# "path" => tracer,
# "nnzs" => length( (tracer[in][:solution]).nzind) )
# end
43 changes: 38 additions & 5 deletions test/test_asp.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using ACEfit
using LinearAlgebra, Random, Test
using LinearAlgebra, Random, Test

##

Expand Down Expand Up @@ -47,7 +47,7 @@ for (select, tolr, tolc) in [ (:final, 10*epsn, 1),
results = ACEfit.solve(solver, A, y)
C = results["C"]
full_path = results["path"]
@show results["nnzs"]
# @show results["nnzs"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)
Expand All @@ -60,7 +60,7 @@ for (select, tolr, tolc) in [ (:final, 10*epsn, 1),
results = ACEfit.solve(solver, At, yt, Av, yv)
C = results["C"]
full_path = results["path"]
@show results["nnzs"]
# @show results["nnzs"]
@show norm(Av * C - yv)
@show norm(C)
@show norm(C - c_ref)
Expand Down Expand Up @@ -91,7 +91,7 @@ for (select, tolr, tolc) in [ (:final, 20*epsn, 1.5),
C_tsvd = results_tsvd["C"]
C = results["C"]

@show results["nnzs"]
# @show results["nnzs"]
@show norm(A * C - y)
@show norm(A * C_tsvd - y)
if norm(A * C_tsvd - y)< norm(A * C - y)
Expand All @@ -106,7 +106,7 @@ for (select, tolr, tolc) in [ (:final, 20*epsn, 1.5),
results = ACEfit.solve(solver, At, yt, Av, yv)
C_tsvd = results_tsvd["C"]
C = results["C"]
@show results["nnzs"]
# @show results["nnzs"]
@show norm(A * C - y)
@show norm(A * C_tsvd - y)

Expand All @@ -117,3 +117,36 @@ for (select, tolr, tolc) in [ (:final, 20*epsn, 1.5),
end
end

##

# Testing the "select" function
solver_final = ACEfit.ASP(
P = I,
select = :final,
tsvd = false,
nstore = 100,
loglevel = 0
)

results_final = ACEfit.solve(solver_final, At, yt, Av, yv)
tracer_final = results_final["path"]

# Warm-start the solver using the tracer from the final iteration
# select best solution with <= 73 non-zero entries
select = (:bysize, 73)
C_select, _ = ACEfit.asp_select(tracer_final, select)
@test( length(C_select.nzind) <= 73 )

# Check if starting the solver initially with (:bysize, 73) gives the same result
solver_bysize = ACEfit.ASP(
P = I,
select = (:bysize, 73),
tsvd = false,
nstore = 100,
loglevel = 0
)

results_bysize = ACEfit.solve(solver_bysize, At, yt, Av, yv)
@test results_bysize["C"] == C_select # works


0 comments on commit 8062ad9

Please sign in to comment.