Skip to content

Commit

Permalink
Fix sparse direct solver matrix construction when HypreParMatrix comm…
Browse files Browse the repository at this point in the history
…unicator is different than the solvers (for example, many processors with empty local matrices)
  • Loading branch information
sebastiangrimberg committed Jun 5, 2023
1 parent 7285a10 commit 64fa330
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 12 deletions.
47 changes: 41 additions & 6 deletions palace/linalg/strumpack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ template <typename StrumpackSolverType>
StrumpackSolverBase<StrumpackSolverType>::StrumpackSolverBase(
MPI_Comm comm, int sym_fact_type, strumpack::CompressionType comp_type, double lr_tol,
int butterfly_l, int lossy_prec, int print)
: StrumpackSolverType(comm)
: StrumpackSolverType(comm), comm(comm)
{
// Configure the solver.
this->SetPrintFactorStatistics(print > 1);
Expand Down Expand Up @@ -71,18 +71,53 @@ template <typename StrumpackSolverType>
// Set the operator to factorize. The input must be either a ParOperator (which is assembled
// to a HypreParMatrix) or a HypreParMatrix; anything else triggers MFEM_VERIFY.
void StrumpackSolverBase<StrumpackSolverType>::SetOperator(const Operator &op)
{
  // Convert the input operator to a distributed STRUMPACK matrix (always assume a symmetric
  // sparsity pattern). This is very similar to the MFEM STRUMPACKRowLocMatrix from a
  // HypreParMatrix but avoids using the communicator from the Hypre matrix in the case that
  // the solver is constructed on a different communicator.
  const mfem::HypreParMatrix *hypA;
  const auto *PtAP = dynamic_cast<const ParOperator *>(&op);
  if (PtAP)
  {
    hypA = &const_cast<ParOperator *>(PtAP)->ParallelAssemble();
  }
  else
  {
    hypA = dynamic_cast<const mfem::HypreParMatrix *>(&op);
    MFEM_VERIFY(hypA, "StrumpackSolver requires a HypreParMatrix operator!");
  }

  // Merge the diagonal and off-diagonal blocks into a single local CSR matrix. The merge is
  // bracketed by HostRead()/HypreRead() calls on the Hypre matrix.
  hypre_ParCSRMatrix *parcsr =
      (hypre_ParCSRMatrix *)const_cast<mfem::HypreParMatrix &>(*hypA);
  hypA->HostRead();
  hypre_CSRMatrix *csr = hypre_MergeDiagAndOffd(parcsr);
  hypA->HypreRead();

  // Create the STRUMPACKRowLocMatrix by taking the internal data from a hypre_CSRMatrix.
  HYPRE_Int n_loc = csr->num_rows;
  HYPRE_BigInt first_row = parcsr->first_row_index;
  HYPRE_Int *I = csr->i;
  HYPRE_BigInt *J = csr->big_j;
  double *data = csr->data;

  // Safe to delete the matrix since STRUMPACK copies it on input. Also clean up the Hypre
  // data structure once we are done with it.
#if !defined(HYPRE_BIGINT)
  mfem::STRUMPACKRowLocMatrix A(comm, n_loc, first_row, hypA->GetGlobalNumRows(),
                                hypA->GetGlobalNumCols(), I, J, data, true);
#else
  // With HYPRE_BIGINT the row offsets are not stored as int, so narrow them into a
  // temporary int array. The overflow checks are MFEM_ASSERTs.
  int n_loc_int = static_cast<int>(n_loc);
  MFEM_ASSERT(n_loc == (HYPRE_Int)n_loc_int,
              "Overflow error for local sparse matrix size!");
  mfem::Array<int> II(n_loc_int + 1);
  for (int i = 0; i <= n_loc_int; i++)
  {
    II[i] = static_cast<int>(I[i]);
    MFEM_ASSERT(I[i] == (HYPRE_Int)II[i], "Overflow error for local sparse matrix index!");
  }
  mfem::STRUMPACKRowLocMatrix A(comm, n_loc_int, first_row, hypA->GetGlobalNumRows(),
                                hypA->GetGlobalNumCols(), II, J, data, true);
#endif

  // Hand the matrix to the underlying solver exactly once, on the solver's communicator.
  StrumpackSolverType::SetOperator(A);
  hypre_CSRMatrixDestroy(csr);
}

template class StrumpackSolverBase<mfem::STRUMPACKSolver>;
Expand Down
2 changes: 2 additions & 0 deletions palace/linalg/strumpack.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ template <typename StrumpackSolverType>
class StrumpackSolverBase : public StrumpackSolverType
{
private:
MPI_Comm comm;

strumpack::CompressionType CompressionType(config::LinearSolverData::CompressionType type)
{
switch (type)
Expand Down
49 changes: 43 additions & 6 deletions palace/linalg/superlu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ int GetNpDep(int np, bool use_3d)
} // namespace

SuperLUSolver::SuperLUSolver(MPI_Comm comm, int sym_fact_type, bool use_3d, int print)
: mfem::Solver(), solver(comm, GetNpDep(Mpi::Size(comm), use_3d))
: mfem::Solver(), comm(comm), A(nullptr), solver(comm, GetNpDep(Mpi::Size(comm), use_3d))
{
// Configure the solver.
if (print > 1)
Expand Down Expand Up @@ -74,25 +74,62 @@ SuperLUSolver::SuperLUSolver(MPI_Comm comm, int sym_fact_type, bool use_3d, int

void SuperLUSolver::SetOperator(const Operator &op)
{
// We need to save A because SuperLU does not copy the input matrix. For repeated
// factorizations, always reuse the sparsity pattern.
// For repeated factorizations, always reuse the sparsity pattern. This is very similar to
// the MFEM SuperLURowLocMatrix from a HypreParMatrix but avoids using the communicator
// from the Hypre matrix in the case that the solver is constructed on a different
// communicator.
if (A)
{
solver.SetFact(mfem::superlu::SamePattern_SameRowPerm);
}
const mfem::HypreParMatrix *hypA;
const auto *PtAP = dynamic_cast<const ParOperator *>(&op);
if (PtAP)
{
A = std::make_unique<mfem::SuperLURowLocMatrix>(
const_cast<ParOperator *>(PtAP)->ParallelAssemble());
hypA = &const_cast<ParOperator *>(PtAP)->ParallelAssemble();
}
else
{
A = std::make_unique<mfem::SuperLURowLocMatrix>(op);
hypA = dynamic_cast<const mfem::HypreParMatrix *>(&op);
MFEM_VERIFY(hypA, "SuperLUSolver requires a HypreParMatrix operator!");
}
hypre_ParCSRMatrix *parcsr =
(hypre_ParCSRMatrix *)const_cast<mfem::HypreParMatrix &>(*hypA);
hypA->HostRead();
hypre_CSRMatrix *csr = hypre_MergeDiagAndOffd(parcsr);
hypA->HypreRead();

// Create the SuperLURowLocMatrix by taking the internal data from a hypre_CSRMatrix.
HYPRE_Int n_loc = csr->num_rows;
HYPRE_BigInt first_row = parcsr->first_row_index;
HYPRE_Int *I = csr->i;
HYPRE_BigInt *J = csr->big_j;
double *data = csr->data;

// We need to save A because SuperLU does not copy the input matrix. Also clean up the
// Hypre data structure once we are done with it.
#if !defined(HYPRE_BIGINT)
A = std::make_unique<mfem::SuperLURowLocMatrix>(comm, n_loc, first_row,
hypA->GetGlobalNumRows(),
hypA->GetGlobalNumCols(), I, J, data);
#else
int n_loc_int = static_cast<int>(n_loc);
MFEM_ASSERT(n_loc == (HYPRE_Int)n_loc_int,
"Overflow error for local sparse matrix size!");
mfem::Array<int> II(n_loc_int + 1);
for (int i = 0; i <= n_loc_int; i++)
{
II[i] = static_cast<int>(I[i]);
MFEM_ASSERT(I[i] == (HYPRE_Int)II[i], "Overflow error for local sparse matrix index!");
}
A = std::make_unique<mfem::SuperLURowLocMatrix>(comm, n_loc_int, first_row,
hypA->GetGlobalNumRows(),
hypA->GetGlobalNumCols(), II, J, data);
#endif
solver.SetOperator(*A);
height = solver.Height();
width = solver.Width();
hypre_CSRMatrixDestroy(csr);
}

} // namespace palace
Expand Down
1 change: 1 addition & 0 deletions palace/linalg/superlu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ namespace palace
class SuperLUSolver : public mfem::Solver
{
private:
MPI_Comm comm;
std::unique_ptr<mfem::SuperLURowLocMatrix> A;
mfem::SuperLUSolver solver;

Expand Down

0 comments on commit 64fa330

Please sign in to comment.