diff --git a/src/optimization/gauss_newton.jl b/src/optimization/gauss_newton.jl index d6709cd..fb58673 100644 --- a/src/optimization/gauss_newton.jl +++ b/src/optimization/gauss_newton.jl @@ -70,7 +70,12 @@ function opt_gauss_newton!( break end - d = solve_linlsqr!(Jac, res, linlsqr, droptol) + if (linlsqr isa LinLsqrSolve) # Refactored behaviour + d = solve_linlsqr!(Jac, res, linlsqr) + else # Old behavior + d = solve_linlsqr!(Jac, res, linlsqr, droptol) + end + x = get_coeffs(graph, cref) x -= γ0 * d set_coeffs!(graph, x, cref) diff --git a/src/optimization/opt_common.jl b/src/optimization/opt_common.jl index ddd9a2e..fb15a13 100644 --- a/src/optimization/opt_common.jl +++ b/src/optimization/opt_common.jl @@ -27,6 +27,68 @@ function adjust_for_errtype!(A, b, objfun_vals, errtype) return (A, b) end + +abstract type LinLsqrSolve end +struct BackslashLinLsqrSolve <: LinLsqrSolve; +end +function solve_linlsqr!(A,b,::BackslashLinLsqrSolve) + d = A \ b +end + +struct RealBackslashLinLsqrSolve <: LinLsqrSolve; end +function solve_linlsqr!(A,b,::RealBackslashLinLsqrSolve) + d = vcat(real(A), imag(A)) \ vcat(real(b), imag(b)) +end + +struct NormEqLinLsqrSolve <: LinLsqrSolve; end +function solve_linsqr!(A,b,::NormEqLinLsqrSolve) + d = (A' * A) \ (A' * b) +end + +struct RealNormEqLinLsqrSolve <: LinLsqrSolve; end +function solve_linlsqr!(A,b,::RealNormEqLinLsqrSolve) + Ar = real(A) + Ai = imag(A) + br = real(b) + bi = imag(b) + d = (Ar' * Ar + Ai' * Ai) \ (Ar' * br + Ai' * bi) +end + +struct SVDLsqrSolve <: LinLsqrSolve; + tp + droptol + fixed_rank +end +function solve_linlsqr!(A,b,solver::SVDLsqrSolve) + + if (solver.tp == :real_svd) + A = vcat(real(A), imag(A)) + b = vcat(real(b), imag(b)) + end + if (eltype(A) == BigFloat || eltype(A) == Complex{BigFloat}) + Sfact = svd!(A; full = false, alg = nothing) + else + Sfact = svd(A) + end + d = Sfact.S + # Use pseudoinverse if droptol>0 + Z = (d / d[1]) .< solver.droptol; + II=findall((!).(Z)); + nonzero=II[1:Int(min(solver.fixed_rank,length(II)))]; + + # Only select index nonzero + dinv=zeros(eltype(d),size(d)); + dinv[1:length(nonzero)] = 1 ./ d[1:length(nonzero)]; + + # No explicit construction, only multiplication + # JJ0=Sfact.U*Diagonal(d)*Sfact.Vt + d = Sfact.V * (dinv .* (Sfact.U' * b)) + + +end + + + """ d = solve_linlsqr!(A, b, linlsqr, droptol)