Skip to content

Commit

Permalink
tiny refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
Ahdhn committed Sep 6, 2024
1 parent 4f81bba commit 36950c7
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 55 deletions.
2 changes: 1 addition & 1 deletion apps/ARAP/arap.cu
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ int main(int argc, char** argv)
rx.get_context(), weight_matrix, laplace_mat, constraints);

// pre_solve laplace_mat
laplace_mat.pre_solve(Solver::QR, PermuteMethod::NSTDIS);
laplace_mat.pre_solve(rx, Solver::QR, PermuteMethod::NSTDIS);

// launch box for rotation matrix calculation
rxmesh::LaunchBox<CUDABlockSize> lb_rot;
Expand Down
22 changes: 2 additions & 20 deletions apps/MCF/mcf_cusolver_chol.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
#include "rxmesh/matrix/sparse_matrix.cuh"
#include "rxmesh/rxmesh_static.h"

#include "rxmesh/matrix/mgnd_permute.cuh"
#include "rxmesh/matrix/nd_reorder.cuh"

#include "mcf_kernels.cuh"

template <typename T, uint32_t blockThreads>
Expand Down Expand Up @@ -195,24 +192,9 @@ void mcf_cusolver_chol(rxmesh::RXMeshStatic& rx,
// A_mat.solve(B_mat, *X_mat, Solver::QR, PermuteMethod::NSTDIS);
// A_mat.solve(B_mat, *X_mat, Solver::CHOL, PermuteMethod::NSTDIS);

// Pre-Solves
std::vector<int> h_reorder_array;

if (permute_method == PermuteMethod::GPUMGND ||
permute_method == PermuteMethod::GPUND) {

// compute permutation
h_reorder_array.resize(rx.get_num_vertices());

// cuda_nd_reorder(rx, h_reorder_array, Arg.nd_level);

mgnd_permute(rx, h_reorder_array);
// pre-solve
A_mat.pre_solve(rx, Solver::CHOL, permute_method);

// Solving using CHOL
A_mat.pre_solve(Solver::CHOL, permute_method, h_reorder_array.data());
} else {
A_mat.pre_solve(Solver::CHOL, permute_method);
}

// Solve
A_mat.solve(B_mat, *X_mat);
Expand Down
4 changes: 2 additions & 2 deletions apps/NDReorder/test_all_permutations.cu
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ void with_mgnd(rxmesh::RXMeshStatic& rx, const EigeMatT& eigen_mat)
{
std::vector<int> h_permute(eigen_mat.rows());

rxmesh::mgnd_permute(rx, h_permute);
rxmesh::mgnd_permute(rx, h_permute.data());

EXPECT_TRUE(
rxmesh::is_unique_permutation(h_permute.size(), h_permute.data()));
Expand All @@ -162,7 +162,7 @@ void with_cuda_nd(rxmesh::RXMeshStatic& rx, const EigeMatT& eigen_mat)

// rxmesh::cuda_nd_reorder(rx, h_permute, Arg.nd_level);

rxmesh::nd_permute(rx, h_permute);
rxmesh::nd_permute(rx, h_permute.data());

EXPECT_TRUE(
rxmesh::is_unique_permutation(h_permute.size(), h_permute.data()));
Expand Down
2 changes: 1 addition & 1 deletion apps/SCP/scp.cu
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ int main(int argc, char** argv)
uv_mat.fill_random();

// factorize the matrix
Lc.pre_solve(Solver::CHOL, PermuteMethod::NSTDIS);
Lc.pre_solve(rx, Solver::CHOL, PermuteMethod::NSTDIS);

// the power method
int iterations = 32;
Expand Down
6 changes: 2 additions & 4 deletions include/rxmesh/matrix/mgnd_permute.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,10 @@ __global__ static void assign_permutation(const Context context,
* is order last. h_permute should be allocated with size equal to num of
* vertices of the mesh.
*/
inline void mgnd_permute(const RXMeshStatic& rx, std::vector<int>& h_permute)
inline void mgnd_permute(const RXMeshStatic& rx, int* h_permute)
{
constexpr uint32_t blockThreads = 256;

h_permute.resize(rx.get_num_vertices());

// auto v_ordering = *rx.add_vertex_attribute<uint32_t>("v_ordering", 1);

int* d_permute = nullptr;
Expand Down Expand Up @@ -147,7 +145,7 @@ inline void mgnd_permute(const RXMeshStatic& rx, std::vector<int>& h_permute)
// h_permute[v_order_idx] = v_linea_id;
// });

CUDA_ERROR(cudaMemcpy(h_permute.data(),
CUDA_ERROR(cudaMemcpy(h_permute,
d_permute,
rx.get_num_vertices() * sizeof(int),
cudaMemcpyDeviceToHost));
Expand Down
13 changes: 6 additions & 7 deletions include/rxmesh/matrix/nd_permute.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ void construct_a_simple_chordal_graph(Graph<T>& graph)
*/
template <typename T>
void construct_patches_neighbor_graph(
RXMeshStatic& rx,
const RXMeshStatic& rx,
Graph<T>& patches_graph,
const std::vector<int>& h_patch_graph_edge_weight)
{
Expand Down Expand Up @@ -271,7 +271,7 @@ void heavy_max_matching(const RXMeshStatic& rx,

std::vector<integer_t> rands(graph.n);
fill_with_random_numbers(rands.data(), rands.size());
//fill_with_sequential_numbers(rands.data(), rands.size());
// fill_with_sequential_numbers(rands.data(), rands.size());


for (int k = 0; k < graph.n; ++k) {
Expand Down Expand Up @@ -612,7 +612,7 @@ void create_dfs_indexing(const int level,
}
}

void permute_separators(RXMeshStatic& rx,
void permute_separators(const RXMeshStatic& rx,
VertexAttribute<int>& v_index,
MaxMatchTree<int>& max_match_tree,
int* d_permute,
Expand Down Expand Up @@ -736,9 +736,8 @@ void permute_separators(RXMeshStatic& rx,
GPU_FREE(d_count);
}

void nd_permute(RXMeshStatic& rx, std::vector<int>& h_permute)
void nd_permute(RXMeshStatic& rx, int* h_permute)
{
h_permute.resize(rx.get_num_vertices());

auto v_index = *rx.add_vertex_attribute<int>("index", 1);

Expand Down Expand Up @@ -814,9 +813,9 @@ void nd_permute(RXMeshStatic& rx, std::vector<int>& h_permute)
timer.elapsed_millis(),
gtimer.elapsed_millis());

CUDA_ERROR(cudaMemcpy(h_permute.data(),
CUDA_ERROR(cudaMemcpy(h_permute,
d_permute,
h_permute.size() * sizeof(int),
rx.get_num_patches() * sizeof(int),
cudaMemcpyDeviceToHost));

GPU_FREE(d_permute);
Expand Down
30 changes: 13 additions & 17 deletions include/rxmesh/matrix/sparse_matrix.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
#include "rxmesh/matrix/permute_util.h"
#include "rxmesh/matrix/sparse_matrix_kernels.cuh"

#include "rxmesh/matrix/mgnd_permute.cuh"
#include "rxmesh/matrix/nd_permute.cuh"

#include "rxmesh/launch_box.h"

#include <Eigen/Sparse>
Expand Down Expand Up @@ -802,8 +805,7 @@ struct SparseMatrix
* the solving process. Any other function call order would be undefined.
* @param reorder: the reorder method applied.
*/
__host__ void permute(PermuteMethod reorder,
IndexT* h_custom_reordering = nullptr)
__host__ void permute(RXMeshStatic& rx, PermuteMethod reorder)
{
permute_alloc(reorder);

Expand Down Expand Up @@ -843,17 +845,11 @@ struct SparseMatrix
m_h_solver_col_idx,
NULL,
m_h_permute));
} else if (reorder == PermuteMethod::GPUMGND ||
reorder == PermuteMethod::GPUND) {
if (h_custom_reordering == nullptr) {
RXMESH_ERROR(
"SparseMatrix::permute() CUSTOM reordering is specified "
"but no reordering array is provided!");
m_use_reorder = false;
return;
}
std::memcpy(
m_h_permute, h_custom_reordering, m_num_rows * sizeof(IndexT));
} else if (reorder == PermuteMethod::GPUMGND) {
mgnd_permute(rx, m_h_permute);

} else if (reorder == PermuteMethod::GPUND) {
nd_permute(rx, m_h_permute);
} else {
RXMESH_ERROR("SparseMatrix::permute() incompatible reorder method");
}
Expand Down Expand Up @@ -1328,9 +1324,9 @@ struct SparseMatrix
* sparse matrix before calling the solve() method below. After calling this
* pre_solve(), solver() can be called with multiple right hand sides
*/
__host__ void pre_solve(Solver solver,
PermuteMethod reorder = PermuteMethod::NSTDIS,
IndexT* h_custom_reordering = nullptr)
__host__ void pre_solve(RXMeshStatic& rx,
Solver solver,
PermuteMethod reorder = PermuteMethod::NSTDIS)
{
if (solver != Solver::CHOL && solver != Solver::QR) {
RXMESH_WARN(
Expand All @@ -1341,7 +1337,7 @@ struct SparseMatrix
m_current_solver = solver;

permute_alloc(reorder);
permute(reorder, h_custom_reordering);
permute(rx, reorder);
analyze_pattern(solver);
post_analyze_alloc(solver);
factorize(solver);
Expand Down
6 changes: 3 additions & 3 deletions tests/RXMesh_test/test_sparse_matrix.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ TEST(RXMeshStatic, SparseMatrixSimpleSolve)
// matrix.

using namespace rxmesh;

// generate rxmesh obj
std::string obj_path = rxmesh_args.obj_file_name;
RXMeshStatic rx(obj_path);
Expand Down Expand Up @@ -327,7 +327,7 @@ TEST(RXMeshStatic, SparseMatrixSimpleSolve)
TEST(RXMeshStatic, SparseMatrixLowerLevelAPISolve)
{
using namespace rxmesh;


// generate rxmesh obj
std::string obj_path = rxmesh_args.obj_file_name;
Expand Down Expand Up @@ -356,7 +356,7 @@ TEST(RXMeshStatic, SparseMatrixLowerLevelAPISolve)
rx.get_context(), *coords, A_mat, X_mat, B_mat, time_step);

// A_mat.solve(B_mat, X_mat, Solver::CHOL, PermuteMethod::NSTDIS);
A_mat.pre_solve(Solver::CHOL, PermuteMethod::NSTDIS);
A_mat.pre_solve(rx, Solver::CHOL, PermuteMethod::NSTDIS);
A_mat.solve(B_mat, X_mat);

A_mat.multiply(X_mat, ret_mat);
Expand Down

0 comments on commit 36950c7

Please sign in to comment.