Skip to content

Commit

Permalink
Avoid use of z MF in solve_cg
Browse files Browse the repository at this point in the history
  • Loading branch information
eebasso committed Nov 2, 2023
1 parent 6eaddb1 commit a3fc3e1
Showing 1 changed file with 61 additions and 43 deletions.
104 changes: 61 additions & 43 deletions Src/LinearSolvers/MLMG/AMReX_MLCGSolver.H
Original file line number Diff line number Diff line change
Expand Up @@ -96,22 +96,27 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)

MF ph(ba, dm, ncomp, sol.nGrowVect(), MFInfo(), factory);
MF sh(ba, dm, ncomp, sol.nGrowVect(), MFInfo(), factory);
ph.setVal(RT(0.0));
sh.setVal(RT(0.0));

MF sorig(ba, dm, ncomp, nghost, MFInfo(), factory);
MF p (ba, dm, ncomp, nghost, MFInfo(), factory);
MF r (ba, dm, ncomp, nghost, MFInfo(), factory);
MF s (ba, dm, ncomp, nghost, MFInfo(), factory);
MF rh (ba, dm, ncomp, nghost, MFInfo(), factory);
MF v (ba, dm, ncomp, nghost, MFInfo(), factory);
MF t (ba, dm, ncomp, nghost, MFInfo(), factory);

// Compute residual r
Lp.correctionResidual(amrlev, mglev, r, sol, rhs, MLLinOpT<MF>::BCMode::Homogeneous);

// Then normalize
Lp.normalize(amrlev, mglev, r);

sorig.LocalCopy(sol,0,0,ncomp,nghost);
rh.LocalCopy (r ,0,0,ncomp,nghost);

sol.setVal(RT(0.0));

RT rnorm = norm_inf(r);
const RT rnorm0 = rnorm;

Expand All @@ -121,7 +126,7 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
}
int ret = 0;
iter = 1;
RT alpha, beta, omega, rho, rhTv;
RT rho_1 = 0, alpha = 0, omega = 0;

if ( rnorm0 == 0 || rnorm0 < eps_abs )
{
Expand All @@ -131,25 +136,31 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
<< ", rnorm = " << rnorm
<< ", eps_abs = " << eps_abs << std::endl;
}
sol.setVal(RT(0.0)); // Not sure why we set to 0 in this case
return ret;
}

rho = dotxy(rh,r); // Move this to here to eliminate usage of rho_1

p.LocalCopy(r,0,0,ncomp,nghost); // This is the true initialization. Move this to here to avoid iter==1 check every iteration

for (; iter <= maxiter; ++iter)
{
const RT rho = dotxy(rh,r);
if ( rho == 0 )
{
ret = 1; break;
}
if ( iter == 1 )
{
p.LocalCopy(r,0,0,ncomp,nghost);
}
else
{
const RT beta = (rho/rho_1)*(alpha/omega);
MF::Saxpy(p, -omega, v, 0, 0, ncomp, nghost); // p += -omega*v
MF::Xpay(p, beta, r, 0, 0, ncomp, nghost); // p = r + beta*p
}
ph.LocalCopy(p,0,0,ncomp,nghost);
Lp.apply(amrlev, mglev, v, ph, MLLinOpT<MF>::BCMode::Homogeneous, MLLinOpT<MF>::StateMode::Correction);
Lp.normalize(amrlev, mglev, v);

rhTv = dotxy(rh,v);
RT rhTv = dotxy(rh,v);
if ( rhTv != RT(0.0) )
{
alpha = rho/rhTv;
Expand All @@ -158,11 +169,10 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
{
ret = 2; break;
}

MF::Saxpy(sol, alpha, ph, 0, 0, ncomp, nghost); // sol += alpha * ph
MF::Saxpy(r, -alpha, v, 0, 0, ncomp, nghost); // r -= alpha * v
MF::LinComb(s, RT(1.0), r, 0, -alpha, v, 0, 0, ncomp, nghost); // s = r - alpha * v

rnorm = norm_inf(r);
rnorm = norm_inf(s);

if ( verbose > 2 && ParallelDescriptor::IOProcessor() )
{
Expand All @@ -174,15 +184,15 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)

if ( rnorm < eps_rel*rnorm0 || rnorm < eps_abs ) { break; }

sh.LocalCopy(r,0,0,ncomp,nghost);
sh.LocalCopy(s,0,0,ncomp,nghost);
Lp.apply(amrlev, mglev, t, sh, MLLinOpT<MF>::BCMode::Homogeneous, MLLinOpT<MF>::StateMode::Correction);
Lp.normalize(amrlev, mglev, t);
//
// This is a little funky. I want to elide one of the reductions
// in the following two dotxy()s. We do that by calculating the "local"
// values and then reducing the two local values at the same time.
//
RT tvals[2] = { dotxy(t,t,true), dotxy(t,r,true) };
RT tvals[2] = { dotxy(t,t,true), dotxy(t,s,true) };

BL_PROFILE_VAR("MLCGSolver::ParallelAllReduce", blp_par);
ParallelAllReduce::Sum(tvals,2,Lp.BottomCommunicator());
Expand All @@ -197,7 +207,7 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
ret = 3; break;
}
MF::Saxpy(sol, omega, sh, 0, 0, ncomp, nghost); // sol += omega * sh
MF::Saxpy(r, -omega, t, 0, 0, ncomp, nghost); // r -= omega * t
MF::LinComb(r, RT(1.0), s, 0, -omega, t, 0, 0, ncomp, nghost); // r = s - omega * t

rnorm = norm_inf(r);

Expand All @@ -215,14 +225,7 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
{
ret = 4; break;
}
// rho_1 = rho; // no need to save old rho since we only need rhTv to compute beta

rho = dotxy(rh,r);

beta = rho / (rhTv * omega); // This is a result of alpha = rho_1 / rhTv

MF::Saxpy(p, -omega, v, 0, 0, ncomp, nghost); // p += -omega*v
MF::Xpay(p, beta, r, 0, 0, ncomp, nghost); // p = r + beta*p
rho_1 = rho;
}

if ( verbose > 0 )
Expand All @@ -241,9 +244,14 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
ret = 8;
}

if ( ( ret != 0 && ret != 8 ) || (rnorm > rnorm0) )
if ( ( ret == 0 || ret == 8 ) && (rnorm < rnorm0) )
{
sol.LocalCopy(sorig,0,0,ncomp,nghost);
sol.LocalAdd(sorig, 0, 0, ncomp, nghost);
}
else
{
sol.setVal(RT(0.0));
sol.LocalAdd(sorig, 0, 0, ncomp, nghost);
}

return ret;
Expand All @@ -262,6 +270,7 @@ MLCGSolverT<MF>::solve_cg (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
const auto& factory = sol.Factory();

MF p(ba, dm, ncomp, sol.nGrowVect(), MFInfo(), factory);
p.setVal(RT(0.0));

MF sorig(ba, dm, ncomp, nghost, MFInfo(), factory);
MF r (ba, dm, ncomp, nghost, MFInfo(), factory);
Expand All @@ -271,6 +280,8 @@ MLCGSolverT<MF>::solve_cg (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)

Lp.correctionResidual(amrlev, mglev, r, sol, rhs, MLLinOpT<MF>::BCMode::Homogeneous);

sol.setVal(RT(0.0));

RT rnorm = norm_inf(r);
const RT rnorm0 = rnorm;

Expand All @@ -279,6 +290,7 @@ MLCGSolverT<MF>::solve_cg (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
amrex::Print() << "MLCGSolver_CG: Initial error (error0) : " << rnorm0 << '\n';
}

RT rho_1 = 0;
int ret = 0;
iter = 1;

Expand All @@ -289,20 +301,30 @@ MLCGSolverT<MF>::solve_cg (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
<< ", rnorm = " << rnorm
<< ", eps_abs = " << eps_abs << std::endl;
}
sol.setVal(RT(0.0)); // Not sure why we do this in this case
return ret;
}

RT alpha, beta, rho, rho_1, pw;

rho = dotxy(r,r);
p.LocalCopy(r,0,0,ncomp,nghost);

for (; iter <= maxiter; ++iter)
{
RT rho = dotxy(z,r);

if ( rho == 0 )
{
ret = 1; break;
}
if (iter == 1)
{
p.LocalCopy(r,0,0,ncomp,nghost);
}
else
{
RT beta = rho/rho_1;
MF::Xpay(p, beta, r, 0, 0, ncomp, nghost); // p = r + beta * p
}
Lp.apply(amrlev, mglev, q, p, MLLinOpT<MF>::BCMode::Homogeneous, MLLinOpT<MF>::StateMode::Correction);

pw = dotxy(p,q);
RT alpha;
RT pw = dotxy(p,q);
if ( pw != RT(0.0))
{
alpha = rho/pw;
Expand Down Expand Up @@ -333,16 +355,7 @@ MLCGSolverT<MF>::solve_cg (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)

if ( rnorm < eps_rel*rnorm0 || rnorm < eps_abs ) { break; }

if ( rho == 0 )
{
ret = 1; break;
}

rho_1 = rho;
rho = dotxy(r,r);

beta = rho/rho_1;
MF::Xpay(p, beta, r, 0, 0, ncomp, nghost); // p = r + beta * p
}

if ( verbose > 0 )
Expand All @@ -353,17 +366,22 @@ MLCGSolverT<MF>::solve_cg (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
<< rnorm/(rnorm0) << '\n';
}

if ( ret == 0 && rnorm > eps_rel*rnorm0 && rnorm > eps_abs )
if ( ret == 0 && rnorm > eps_rel*rnorm0 && rnorm > eps_abs )
{
if ( verbose > 0 && ParallelDescriptor::IOProcessor() ) {
amrex::Warning("MLCGSolver_cg: failed to converge!");
}
ret = 8;
}

if ( ( ret != 0 && ret != 8 ) || (rnorm > rnorm0) )
if ( ( ret == 0 || ret == 8 ) && (rnorm < rnorm0) )
{
sol.LocalAdd(sorig, 0, 0, ncomp, nghost);
}
else
{
sol.LocalCopy(sorig,0,0,ncomp,nghost);
sol.setVal(RT(0.0));
sol.LocalAdd(sorig, 0, 0, ncomp, nghost);
}

return ret;
Expand Down

0 comments on commit a3fc3e1

Please sign in to comment.