diff --git a/Src/LinearSolvers/MLMG/AMReX_MLCGSolver.H b/Src/LinearSolvers/MLMG/AMReX_MLCGSolver.H index 3a4fe6845d0..f14e995a1f7 100644 --- a/Src/LinearSolvers/MLMG/AMReX_MLCGSolver.H +++ b/Src/LinearSolvers/MLMG/AMReX_MLCGSolver.H @@ -96,6 +96,8 @@ MLCGSolverT::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs) MF ph(ba, dm, ncomp, sol.nGrowVect(), MFInfo(), factory); MF sh(ba, dm, ncomp, sol.nGrowVect(), MFInfo(), factory); + ph.setVal(RT(0.0)); + sh.setVal(RT(0.0)); MF sorig(ba, dm, ncomp, nghost, MFInfo(), factory); MF p (ba, dm, ncomp, nghost, MFInfo(), factory); @@ -104,14 +106,16 @@ MLCGSolverT::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs) MF v (ba, dm, ncomp, nghost, MFInfo(), factory); MF t (ba, dm, ncomp, nghost, MFInfo(), factory); - // Compute residual r Lp.correctionResidual(amrlev, mglev, r, sol, rhs, MLLinOpT::BCMode::Homogeneous); + // Then normalize Lp.normalize(amrlev, mglev, r); sorig.LocalCopy(sol,0,0,ncomp,nghost); rh.LocalCopy (r ,0,0,ncomp,nghost); + sol.setVal(RT(0.0)); + RT rnorm = norm_inf(r); const RT rnorm0 = rnorm; @@ -121,7 +125,7 @@ MLCGSolverT::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs) } int ret = 0; iter = 1; - RT alpha, beta, omega, rho, rhTv; + RT rho_1 = 0, alpha = 0, omega = 0; if ( rnorm0 == 0 || rnorm0 < eps_abs ) { @@ -131,25 +135,31 @@ MLCGSolverT::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs) << ", rnorm = " << rnorm << ", eps_abs = " << eps_abs << std::endl; } - sol.setVal(RT(0.0)); // Not sure why we set to 0 in this case return ret; } - rho = dotxy(rh,r); // Move this to here to eliminate usage of rho_1 - - p.LocalCopy(r,0,0,ncomp,nghost); // This is the true initialization. Move this to here to avoid iter==1 check every iteration - for (; iter <= maxiter; ++iter) { + const RT rho = dotxy(rh,r); if ( rho == 0 ) { ret = 1; break; } + if ( iter == 1 ) + { + p.LocalCopy(r,0,0,ncomp,nghost); + } + else + { + const RT beta = (rho/rho_1)*(alpha/omega); + MF::Saxpy(p, -omega, v, 0, 0, ncomp, nghost); // p += -omega*v + MF::Xpay(p, beta, r, 0, 0, ncomp, nghost); // p = r + beta*p + } ph.LocalCopy(p,0,0,ncomp,nghost); Lp.apply(amrlev, mglev, v, ph, MLLinOpT::BCMode::Homogeneous, MLLinOpT::StateMode::Correction); Lp.normalize(amrlev, mglev, v); - rhTv = dotxy(rh,v); + RT rhTv = dotxy(rh,v); if ( rhTv != RT(0.0) ) { alpha = rho/rhTv; @@ -158,9 +168,8 @@ MLCGSolverT::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs) { ret = 2; break; } - MF::Saxpy(sol, alpha, ph, 0, 0, ncomp, nghost); // sol += alpha * ph - MF::Saxpy(r, -alpha, v, 0, 0, ncomp, nghost); // r -= alpha * v + MF::Saxpy(r, -alpha, v, 0, 0, ncomp, nghost); // r -= alpha * v rnorm = norm_inf(r); @@ -197,7 +206,7 @@ MLCGSolverT::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs) ret = 3; break; } MF::Saxpy(sol, omega, sh, 0, 0, ncomp, nghost); // sol += omega * sh - MF::Saxpy(r, -omega, t, 0, 0, ncomp, nghost); // r -= omega * t + MF::Saxpy(r, -omega, t, 0, 0, ncomp, nghost); // r -= omega * t rnorm = norm_inf(r); @@ -215,14 +224,7 @@ MLCGSolverT::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs) { ret = 4; break; } - // rho_1 = rho; // no need to save old rho since we only need rhTv to compute beta - - rho = dotxy(rh,r); - - beta = rho / (rhTv * omega); // This is a result of alpha = rho_1 / rhTv - - MF::Saxpy(p, -omega, v, 0, 0, ncomp, nghost); // p += -omega*v - MF::Xpay(p, beta, r, 0, 0, ncomp, nghost); // p = r + beta*p + rho_1 = rho; } if ( verbose > 0 ) @@ -241,9 +243,14 @@ MLCGSolverT::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs) ret = 8; } - if ( ( ret != 0 && ret != 8 ) || (rnorm > rnorm0) ) + if ( ( ret == 0 || ret == 8 ) && (rnorm < rnorm0) ) { - sol.LocalCopy(sorig,0,0,ncomp,nghost); + sol.LocalAdd(sorig, 0, 0, ncomp, nghost); + } + else + { + sol.setVal(RT(0.0)); + sol.LocalAdd(sorig, 0, 0, ncomp, nghost); } return ret; @@ -262,15 +269,19 @@ MLCGSolverT::solve_cg (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs) const auto& factory = sol.Factory(); MF p(ba, dm, ncomp, sol.nGrowVect(), MFInfo(), factory); + p.setVal(RT(0.0)); MF sorig(ba, dm, ncomp, nghost, MFInfo(), factory); MF r (ba, dm, ncomp, nghost, MFInfo(), factory); + MF z (ba, dm, ncomp, nghost, MFInfo(), factory); MF q (ba, dm, ncomp, nghost, MFInfo(), factory); sorig.LocalCopy(sol,0,0,ncomp,nghost); Lp.correctionResidual(amrlev, mglev, r, sol, rhs, MLLinOpT::BCMode::Homogeneous); + sol.setVal(RT(0.0)); + RT rnorm = norm_inf(r); const RT rnorm0 = rnorm; @@ -279,6 +290,7 @@ MLCGSolverT::solve_cg (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs) amrex::Print() << "MLCGSolver_CG: Initial error (error0) : " << rnorm0 << '\n'; } + RT rho_1 = 0; int ret = 0; iter = 1; @@ -289,20 +301,32 @@ MLCGSolverT::solve_cg (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs) << ", rnorm = " << rnorm << ", eps_abs = " << eps_abs << std::endl; } - sol.setVal(RT(0.0)); // Not sure why we do this in this case return ret; } - RT alpha, beta, rho, rho_1, pw; - - rho = dotxy(r,r); - p.LocalCopy(r,0,0,ncomp,nghost); - for (; iter <= maxiter; ++iter) { + z.LocalCopy(r,0,0,ncomp,nghost); + + RT rho = dotxy(z,r); + + if ( rho == 0 ) + { + ret = 1; break; + } + if (iter == 1) + { + p.LocalCopy(z,0,0,ncomp,nghost); + } + else + { + RT beta = rho/rho_1; + MF::Xpay(p, beta, z, 0, 0, ncomp, nghost); // p = z + beta * p + } Lp.apply(amrlev, mglev, q, p, MLLinOpT::BCMode::Homogeneous, MLLinOpT::StateMode::Correction); - pw = dotxy(p,q); + RT alpha; + RT pw = dotxy(p,q); if ( pw != RT(0.0)) { alpha = rho/pw; @@ -333,16 +357,7 @@ MLCGSolverT::solve_cg (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs) if ( rnorm < eps_rel*rnorm0 || rnorm < eps_abs ) { break; } - if ( rho == 0 ) - { - ret = 1; break; - } - rho_1 = rho; - rho = dotxy(r,r); - - beta = rho/rho_1; - MF::Xpay(p, beta, r, 0, 0, ncomp, nghost); // p = r + beta * p } if ( verbose > 0 ) @@ -353,7 +368,7 @@ MLCGSolverT::solve_cg (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs) << rnorm/(rnorm0) << '\n'; } - if ( ret == 0 && rnorm > eps_rel*rnorm0 && rnorm > eps_abs ) + if ( ret == 0 && rnorm > eps_rel*rnorm0 && rnorm > eps_abs ) { if ( verbose > 0 && ParallelDescriptor::IOProcessor() ) { amrex::Warning("MLCGSolver_cg: failed to converge!"); @@ -361,9 +376,14 @@ MLCGSolverT::solve_cg (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs) ret = 8; } - if ( ( ret != 0 && ret != 8 ) || (rnorm > rnorm0) ) + if ( ( ret == 0 || ret == 8 ) && (rnorm < rnorm0) ) + { + sol.LocalAdd(sorig, 0, 0, ncomp, nghost); + } + else { - sol.LocalCopy(sorig,0,0,ncomp,nghost); + sol.setVal(RT(0.0)); + sol.LocalAdd(sorig, 0, 0, ncomp, nghost); } return ret;