Skip to content

Commit

Permalink
asp bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Christoph Ortner committed Sep 12, 2024
1 parent 929bf85 commit 0910d0d
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 17 deletions.
20 changes: 9 additions & 11 deletions src/asp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ subject to
```
### Constructor Keyword arguments
```julia
ACEfit.ASP(; P = I, select = (:byerror, 1.0),
ACEfit.ASP(; P = I, select = (:byerror, 1.0), tsvd = false, nstore=100,
params...)
```
Expand Down Expand Up @@ -42,33 +43,30 @@ solve(solver::ASP, A, y, Aval=A, yval=y)
```
* `A` : `m`-by-`n` design matrix. (required)
* `b` : `m`-vector. (required)
* `Aval = nothing` : `p`-by-`n` validation matrix (only for `validate` mode).
* `bval = nothing` : `p`- validation vector (only for `validate` mode).
* `Aval = nothing` : `p`-by-`n` validation matrix
* `bval = nothing` : `p`- validation vector
If independent `Aval` and `yval` are provided (instead of detaults `A, y`),
then the solver will use this separate validation set instead of the training
set to select the best solution along the model path.
# """

"""
struct ASP
P
select
mode::Symbol
tsvd::Bool
nstore::Integer
params
end

function ASP(; P = I, select, mode=:train, tsvd=false, nstore=100, params...)
return ASP(P, select, mode, tsvd, nstore, params)
function ASP(; P = I, select, tsvd=false, nstore=100, params...)
return ASP(P, select, tsvd, nstore, params)
end

function solve(solver::ASP, A, y, Aval=A, yval=y)
# Apply preconditioning
AP = A / solver.P
AvalP = Aval / solver.P

tracer = asp_homotopy(AP, y; solver.params...)
tracer = asp_homotopy(AP, y; solver.params..., traceFlag = true)

q = length(tracer)
every = max(1, q / solver.nstore)
Expand All @@ -89,7 +87,7 @@ function solve(solver::ASP, A, y, Aval=A, yval=y)

return Dict( "C" => xs,
"path" => new_post,
"nnzs" => length( (new_tracer[in][:solution]).nzind) )
"nnzs" => length( (new_post[in][:solution]).nzind) )
end


Expand Down
15 changes: 9 additions & 6 deletions test/test_asp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,21 @@ Av = A[val_indices,:]
yt = y[train_indices]
yv = y[val_indices]


for (nstore, n1) in [ (20, 21), (100, 101), (200, 165)]
solver = ACEfit.ASP(P=I, select = :final, nstore = nstore, loglevel=0, traceFlag=true)
solver = ACEfit.ASP(; P=I, select = :final, nstore = nstore, loglevel=0)
results = ACEfit.solve(solver, A, y)
@test length(results["path"]) == n1
end

##

for (select, tolr, tolc) in [ (:final, 10*epsn, 1),
( (:byerror,1.3), 10*epsn, 1),
( (:bysize,73), 1, 10) ]
@show select
local solver, results, C
solver = ACEfit.ASP(P=I, select = select, loglevel=0, traceFlag=true)
solver = ACEfit.ASP(P=I, select = select, loglevel=0)
# without validation
results = ACEfit.solve(solver, A, y)
C = results["C"]
Expand Down Expand Up @@ -77,11 +80,11 @@ for (select, tolr, tolc) in [ (:final, 20*epsn, 1.5),
( (:bysize,73), 1, 10) ]
@show select
local solver, results, C
solver_tsvd = ACEfit.ASP(P=I, select=select, mode=:train, tsvd=true,
nstore=100, loglevel=0, traceFlag=true)
solver_tsvd = ACEfit.ASP(P=I, select=select, tsvd=true,
nstore=100, loglevel=0)

solver = ACEfit.ASP(P=I, select=select, mode=:train, tsvd=false,
nstore=100, loglevel=0, traceFlag=true)
solver = ACEfit.ASP(P=I, select=select, tsvd=false,
nstore=100, loglevel=0)
# without validation
results_tsvd = ACEfit.solve(solver_tsvd, A, y)
results = ACEfit.solve(solver, A, y)
Expand Down

0 comments on commit 0910d0d

Please sign in to comment.