Skip to content

Commit

Permalink
mcf
Browse files Browse the repository at this point in the history
  • Loading branch information
Ahdhn committed Sep 6, 2024
1 parent 9c06d10 commit 7694ec7
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 12 deletions.
30 changes: 18 additions & 12 deletions apps/MCF/mcf.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ struct arg
std::string obj_file_name = STRINGIFY(INPUT_DIR) "dragon.obj";
std::string output_folder = STRINGIFY(OUTPUT_DIR);
std::string perm_method = "gpund";
std::string solver = "chol";
uint32_t device_id = 0;
float time_step = 0.001;
float cg_tolerance = 1e-6;
uint32_t max_num_cg_iter = 1000;
bool use_uniform_laplace = false;
uint32_t nd_level = 1;
bool use_uniform_laplace = false;
char** argv;
int argc;
} Arg;
Expand All @@ -43,11 +43,11 @@ TEST(App, MCF)

ASSERT_TRUE(rx.is_edge_manifold());

// RXMesh Impl
// mcf_cg<dataT>(rx);

// RXMesh cusolver Impl
mcf_cusolver_chol<dataT>(rx, string_to_permute_method(Arg.perm_method));
// Dispatch on the solver requested with -solver. "cg" (also accepted as "CG",
// the spelling used in the usage text) runs the conjugate gradient
// implementation; any other value, including the default "chol", runs the
// cusolver Cholesky implementation with the permutation method named by
// Arg.perm_method.
if (Arg.solver == "cg" || Arg.solver == "CG") {
    mcf_cg<dataT>(rx);
} else {
    mcf_cusolver_chol<dataT>(rx, string_to_permute_method(Arg.perm_method));
}
}

int main(int argc, char** argv)
Expand All @@ -70,11 +70,12 @@ int main(int argc, char** argv)
" -uniform_laplace: Use uniform Laplace weights. Default is {} \n"
" -dt: Time step (delta t). Default is {} \n"
" Hint: should be between (0.001, 1) for cotan Laplace or between (1, 100) for uniform Laplace\n"
" -solver: Solver to use. Options are CG or Chol. Default is {}\n"
" -eps: Conjugate gradient tolerance. Default is {}\n"
" -max_cg_iter: Conjugate gradient maximum number of iterations. Default is {}\n"
" -nd_level: ND level. Default is {}\n"
" -perm: Permutation method for Cholesky factorization. Default is {}\n"
" -max_cg_iter: Conjugate gradient maximum number of iterations. Default is {}\n"
" -device_id: GPU device ID. Default is {}",
Arg.obj_file_name, Arg.output_folder, (Arg.use_uniform_laplace? "true" : "false"), Arg.time_step, Arg.cg_tolerance, Arg.max_num_cg_iter, Arg.device_id);
// Arguments are consumed positionally by the {} placeholders of the usage
// text, so they must appear in the same order as the options are listed:
// -dt (time step) comes before -solver.
Arg.obj_file_name, Arg.output_folder, (Arg.use_uniform_laplace? "true" : "false"), Arg.time_step, Arg.solver, Arg.cg_tolerance, Arg.perm_method, Arg.max_num_cg_iter, Arg.device_id);
// clang-format on
exit(EXIT_SUCCESS);
}
Expand Down Expand Up @@ -106,13 +107,18 @@ int main(int argc, char** argv)
Arg.device_id =
atoi(get_cmd_option(argv, argv + argc, "-device_id"));
}
if (cmd_option_exists(argv, argc + argv, "-nd_level")) {
Arg.nd_level = atoi(get_cmd_option(argv, argv + argc, "-nd_level"));
// The usage text lists the permutation option as -perm, so accept that
// spelling as well as -perm_method. When both are given, -perm_method is
// parsed last and therefore wins.
// NOTE(review): assumes cmd_option_exists matches whole arguments, so that
// "-perm" does not also match "-perm_method" -- confirm against its definition.
if (cmd_option_exists(argv, argc + argv, "-perm")) {
    Arg.perm_method = std::string(get_cmd_option(argv, argv + argc, "-perm"));
}
if (cmd_option_exists(argv, argc + argv, "-perm_method")) {
    Arg.perm_method = std::string(get_cmd_option(argv, argv + argc, "-perm_method"));
}
if (cmd_option_exists(argv, argc + argv, "-solver")) {
Arg.solver = std::string(get_cmd_option(argv, argv + argc, "-solver"));
}
}

RXMESH_TRACE("input= {}", Arg.obj_file_name);
RXMESH_TRACE("output_folder= {}", Arg.output_folder);
RXMESH_TRACE("solver= {}", Arg.solver);
RXMESH_TRACE("perm_method= {}", Arg.perm_method);
RXMESH_TRACE("max_num_cg_iter= {}", Arg.max_num_cg_iter);
RXMESH_TRACE("cg_tolerance= {0:f}", Arg.cg_tolerance);
RXMESH_TRACE("use_uniform_laplace= {}", Arg.use_uniform_laplace);
Expand Down
11 changes: 11 additions & 0 deletions apps/MCF/mcf_cusolver_chol.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,10 @@ void mcf_cusolver_chol(rxmesh::RXMeshStatic& rx,
report.add_member("blockThreads", blockThreads);
report.add_member("PermuteMethod",
permute_method_to_string(permute_method));

// Log the name of the permutation method passed to this function. The value
// printed is the method name, not a duration, so the message must not say
// "took".
RXMESH_INFO("permute_method= {}", permute_method_to_string(permute_method));

float total_time = 0;

CPUTimer timer;
GPUTimer gtimer;
Expand All @@ -220,6 +224,7 @@ void mcf_cusolver_chol(rxmesh::RXMeshStatic& rx,
timer.elapsed_millis(),
gtimer.elapsed_millis());
report.add_member("permute_alloc", timer.elapsed_millis());
total_time+=timer.elapsed_millis();

timer.start();
gtimer.start();
Expand All @@ -230,6 +235,7 @@ void mcf_cusolver_chol(rxmesh::RXMeshStatic& rx,
timer.elapsed_millis(),
gtimer.elapsed_millis());
report.add_member("permute", timer.elapsed_millis());
total_time+=timer.elapsed_millis();


timer.start();
Expand All @@ -241,6 +247,7 @@ void mcf_cusolver_chol(rxmesh::RXMeshStatic& rx,
timer.elapsed_millis(),
gtimer.elapsed_millis());
report.add_member("analyze_pattern", timer.elapsed_millis());
total_time+=timer.elapsed_millis();


timer.start();
Expand All @@ -252,6 +259,7 @@ void mcf_cusolver_chol(rxmesh::RXMeshStatic& rx,
timer.elapsed_millis(),
gtimer.elapsed_millis());
report.add_member("post_analyze_alloc", timer.elapsed_millis());
total_time+=timer.elapsed_millis();


timer.start();
Expand All @@ -263,6 +271,7 @@ void mcf_cusolver_chol(rxmesh::RXMeshStatic& rx,
timer.elapsed_millis(),
gtimer.elapsed_millis());
report.add_member("factorize", timer.elapsed_millis());
total_time+=timer.elapsed_millis();


timer.start();
Expand All @@ -274,7 +283,9 @@ void mcf_cusolver_chol(rxmesh::RXMeshStatic& rx,
timer.elapsed_millis(),
gtimer.elapsed_millis());
report.add_member("solve", timer.elapsed_millis());
total_time+=timer.elapsed_millis();

report.add_member("total_time", total_time);

// move the results to the host
// if we use LU, the data will be on the host and we should not move the
Expand Down

0 comments on commit 7694ec7

Please sign in to comment.